Skip to content

Commit

Permalink
Addressed wcourtney's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
verult committed Jul 26, 2023
1 parent 96642fa commit 861752e
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 37 deletions.
81 changes: 58 additions & 23 deletions cirq-google/cirq_google/engine/stream_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, message: str):


class ResponseDemux:
"""A event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern.
"""An event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern.
A caller can subscribe to the response matching a provided message ID. Only a single caller may
subscribe to each ID.
Expand Down Expand Up @@ -100,15 +100,19 @@ class StreamManager:
The main manager method is `submit()`, which sends the provided job to Quantum Engine through
the stream and returns a future to be completed when either the result is ready or the job has
failed.
failed. The submitted job can also be cancelled by calling `cancel()` on the future returned by
`submit()`.
A new stream is opened during the first `submit()` call, and it stays open. If the stream is
unused, users could close the stream and free management resources by calling `stop()`.
unused, users can close the stream and free management resources by calling `stop()`.
"""

def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
self._grpc_client = grpc_client
self._request_queue: Optional[asyncio.Queue] = None
# Used to determine whether the stream coroutine is actively running, and provides a way to
# cancel it.
self._manage_stream_loop_future: Optional[duet.AwaitableFuture[None]] = None
# TODO(#5996) consider making the scope of response futures local to the relevant tasks
# rather than all of StreamManager.
Expand All @@ -123,9 +127,12 @@ def submit(
If submit() is called for the first time since StreamManager instantiation or since the last
time stop() was called, it will create a new long-running stream.
The job can be cancelled by calling `cancel()` on the returned future.
Args:
project_name: The full project ID resource path associated with the job.
program: The Quantum Engine program representing the circuit to be executed.
program: The Quantum Engine program representing the circuit to be executed. The program
name must be set.
job: The Quantum Engine job to be executed.
Returns:
Expand All @@ -134,27 +141,35 @@ def submit(
Raises:
ProgramAlreadyExistsError if the program already exists.
StreamError if there is a non-retryable error while executing the job.
ValueError if program name is not set.
concurrent.futures.CancelledError if the stream is stopped while a job is in flight.
google.api_core.exceptions.GoogleAPICallError if the stream breaks with a non-retryable
error.
"""
if 'name' not in program:
raise ValueError('Program name must be set.')

if self._manage_stream_loop_future is None:
self._manage_stream_loop_future = self._executor.submit(self._manage_stream)
self._manage_stream_loop_future.add_done_callback(self._manage_stream_cancel())
return self._executor.submit(self._manage_execution, project_name, program, job)

def stop(self) -> None:
    """Closes the open stream and resets all management resources.

    The stream coroutine's future is cancelled only when it exists and has not already
    finished. The manager state is reset in every case, including when the cancellation
    itself raises.
    """
    stream_future = self._manage_stream_loop_future
    try:
        if stream_future is not None and not stream_future.done():
            stream_future.cancel()
    finally:
        # Reset unconditionally, exactly once, even if cancel() raised.
        self._reset()

def _reset(self):
    """Restores the manager's bookkeeping attributes to their freshly-reset values.

    Clears the request queue and the stream loop future, installs a new ResponseDemux, and
    restarts message ID numbering from zero.
    """
    self._next_available_message_id = 0
    self._response_demux = ResponseDemux()
    self._manage_stream_loop_future = None
    self._request_queue = None

@property
def _executor(self) -> AsyncioExecutor:
Expand Down Expand Up @@ -218,15 +233,6 @@ async def _manage_execution(
parent=project_name, quantum_program=program, quantum_job=job
),
)
create_job_request = quantum.QuantumRunStreamRequest(
parent=project_name,
create_quantum_job=quantum.CreateQuantumJobRequest(
parent=program.name, quantum_job=job
),
)
get_result_request = quantum.QuantumRunStreamRequest(
parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name)
)

while self._request_queue is None:
# Wait for the stream coroutine to start.
Expand All @@ -244,15 +250,17 @@ async def _manage_execution(
except google_exceptions.GoogleAPICallError as e:
if _is_retryable_error(e):
# Retry
current_request = get_result_request
current_request = _to_get_result_request(create_program_and_job_request)
continue
# TODO(#5996) add exponential backoff
raise e

# Either when this request is canceled or the _manage_stream() loop is canceled.
except asyncio.CancelledError:
# TODO(#5996) Consider moving this logic into a future done callback, so that the
# the cancellation caller can wait for it to complete.
# TODO(#5996) Consider moving the request future cancellation logic into a future
# done callback, so that the cancellation caller can wait for it to complete.
# TODO(#5996) Check the condition that response_future is not done before
# cancelling, once request cancellation is moved to a callback.
if response_future is not None:
response_future.cancel()
await self._cancel(job.name)
Expand All @@ -264,11 +272,11 @@ async def _manage_execution(
elif 'job' in response:
return response.job
elif 'error' in response:
current_request = _decide_retry_request_or_raise(
current_request = _get_retry_request_or_raise(
response.error,
current_request,
create_program_and_job_request,
create_job_request,
_to_create_job_request(create_program_and_job_request),
)
continue
else:
Expand All @@ -286,7 +294,7 @@ def _generate_message_id(self) -> str:
return message_id


def _decide_retry_request_or_raise(
def _get_retry_request_or_raise(
error: quantum.StreamError,
current_request,
create_program_and_job_request,
Expand Down Expand Up @@ -332,3 +340,30 @@ async def _request_iterator(
"""
while True:
yield await request_queue.get()


def _to_create_job_request(
    create_program_and_job_request: quantum.QuantumRunStreamRequest,
) -> quantum.QuantumRunStreamRequest:
    """Converts a CreateQuantumProgramAndJobRequest stream request to a CreateQuantumJobRequest.

    The stream-level parent and the job are carried over unchanged; the name of the embedded
    program becomes the parent of the new job request.
    """
    program_and_job = create_program_and_job_request.create_quantum_program_and_job
    job_request = quantum.CreateQuantumJobRequest(
        parent=program_and_job.quantum_program.name, quantum_job=program_and_job.quantum_job
    )
    return quantum.QuantumRunStreamRequest(
        parent=create_program_and_job_request.parent, create_quantum_job=job_request
    )


def _to_get_result_request(
    create_program_and_job_request: quantum.QuantumRunStreamRequest,
) -> quantum.QuantumRunStreamRequest:
    """Converts a CreateQuantumProgramAndJobRequest stream request to a GetQuantumResultRequest.

    The stream-level parent is carried over unchanged; the name of the embedded job becomes
    the parent of the result request.
    """
    job_name = create_program_and_job_request.create_quantum_program_and_job.quantum_job.name
    result_request = quantum.GetQuantumResultRequest(parent=job_name)
    return quantum.QuantumRunStreamRequest(
        parent=create_program_and_job_request.parent, get_quantum_result=result_request
    )
52 changes: 38 additions & 14 deletions cirq-google/cirq_google/engine/stream_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import google.api_core.exceptions as google_exceptions

from cirq_google.engine.stream_manager import (
_decide_retry_request_or_raise,
_get_retry_request_or_raise,
ProgramAlreadyExistsError,
ResponseDemux,
StreamError,
Expand Down Expand Up @@ -206,7 +206,7 @@ async def test_publish_exception_after_publishing_response_does_not_change_futur

class TestStreamManager:
@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_expects_result_response(self, client_constructor):
def test_submit_expects_result_response(self, client_constructor):
async def test():
async with duet.timeout_scope(5):
# Arrange
Expand Down Expand Up @@ -234,7 +234,29 @@ async def test():
duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_cancel_future_expects_engine_cancellation_rpc_call(self, client_constructor):
def test_submit_program_without_name_raises(self, client_constructor):
    """Checks that submit() raises ValueError when given a program with no name set."""

    async def test():
        async with duet.timeout_scope(5):
            # Arrange
            result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
            responses = [quantum.QuantumRunStreamResponse(result=result)]
            client = setup_fake_quantum_run_stream_client(
                client_constructor, responses_and_exceptions=responses
            )
            manager = StreamManager(client)
            unnamed_program = quantum.QuantumProgram()

            with pytest.raises(ValueError, match='Program name must be set'):
                await manager.submit(REQUEST_PROJECT_NAME, unnamed_program, REQUEST_JOB)
            manager.stop()

    duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_submit_cancel_future_expects_engine_cancellation_rpc_call(self, client_constructor):
async def test():
async with duet.timeout_scope(5):
fake_client = setup_fake_quantum_run_stream_client(
Expand All @@ -255,7 +277,7 @@ async def test():
duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_stream_broken_twice_expects_retry_with_get_quantum_result_twice(
def test_submit_stream_broken_twice_expects_retry_with_get_quantum_result_twice(
self, client_constructor
):
async def test():
Expand Down Expand Up @@ -294,7 +316,7 @@ async def test():
],
)
@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_with_retryable_stream_breakage_expects_get_result_request(
def test_submit_with_retryable_stream_breakage_expects_get_result_request(
self, client_constructor, error
):
async def test():
Expand Down Expand Up @@ -335,7 +357,9 @@ async def test():
],
)
@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_with_non_retryable_stream_breakage_raises_error(self, client_constructor, error):
def test_submit_with_non_retryable_stream_breakage_raises_error(
self, client_constructor, error
):
async def test():
async with duet.timeout_scope(5):
mock_responses = [
Expand All @@ -359,7 +383,7 @@ async def test():
duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_expects_job_response(self, client_constructor):
def test_submit_expects_job_response(self, client_constructor):
async def test():
async with duet.timeout_scope(5):
expected_job = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0')
Expand All @@ -382,7 +406,7 @@ async def test():
duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_job_does_not_exist_expects_create_quantum_job_request(self, client_constructor):
def test_submit_job_does_not_exist_expects_create_quantum_job_request(self, client_constructor):
async def test():
async with duet.timeout_scope(5):
expected_result = quantum.QuantumResult(
Expand Down Expand Up @@ -414,7 +438,7 @@ async def test():
duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_program_does_not_exist_expects_create_quantum_program_and_job_request(
def test_submit_program_does_not_exist_expects_create_quantum_program_and_job_request(
self, client_constructor
):
async def test():
Expand Down Expand Up @@ -454,7 +478,7 @@ async def test():
duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_program_already_exists_expects_program_already_exists_error(
def test_submit_program_already_exists_expects_program_already_exists_error(
self, client_constructor
):
async def test():
Expand All @@ -478,7 +502,7 @@ async def test():
duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_send_twice_in_parallel_expect_result_responses(self, client_constructor):
def test_submit_twice_in_parallel_expect_result_responses(self, client_constructor):
async def test():
async with duet.timeout_scope(5):
request_job1 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job1')
Expand Down Expand Up @@ -517,7 +541,7 @@ async def test():

# TODO(#5996) Update fake client implementation to support this test case.
# @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
# def test_send_twice_and_break_stream_expect_result_responses(self, client_constructor):
# def test_submit_twice_and_break_stream_expect_result_responses(self, client_constructor):
# async def test():
# async with duet.timeout_scope(5):
# request_job1 = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job1')
Expand Down Expand Up @@ -621,7 +645,7 @@ async def test():
(Code.JOB_DOES_NOT_EXIST, 'create_quantum_job'),
],
)
def test_decide_retry_request_or_raise_expects_stream_error(
def test_get_retry_request_or_raise_expects_stream_error(
self, error_code, current_request_type
):
# This tests a private function, but it's much easier to exhaustively test this function
Expand All @@ -644,7 +668,7 @@ def test_decide_retry_request_or_raise_expects_stream_error(
current_request = get_quantum_result_request

with pytest.raises(StreamError):
_decide_retry_request_or_raise(
_get_retry_request_or_raise(
quantum.StreamError(code=error_code),
current_request,
create_quantum_program_and_job_request,
Expand Down

0 comments on commit 861752e

Please sign in to comment.