Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamManager: Add mechanism to close the request iterator #6263

Merged
merged 7 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 48 additions & 23 deletions cirq-google/cirq_google/engine/stream_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ class StreamManager:

def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
self._grpc_client = grpc_client
# TODO(#5996) Make this local to the asyncio thread.
self._request_queue: Optional[asyncio.Queue] = None
# Used to determine whether the stream coroutine is actively running, and provides a way to
# cancel it.
self._manage_stream_loop_future: Optional[duet.AwaitableFuture[None]] = None
Expand All @@ -121,6 +119,16 @@ def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
# interface.
self._response_demux = ResponseDemux()
self._next_available_message_id = 0
# Construct queue in AsyncioExecutor to ensure it binds to the correct event loop, since it
# is used by asyncio coroutines.
self._request_queue = self._executor.submit(self._make_request_queue).result()

async def _make_request_queue(self) -> asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]]:
"""Returns a queue used to back the request iterator passed to the stream.

If `None` is put into the queue, the request iterator will stop.
"""
return asyncio.Queue()

def submit(
self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
Expand Down Expand Up @@ -153,8 +161,12 @@ def submit(
raise ValueError('Program name must be set.')

if self._manage_stream_loop_future is None or self._manage_stream_loop_future.done():
self._manage_stream_loop_future = self._executor.submit(self._manage_stream)
return self._executor.submit(self._manage_execution, project_name, program, job)
self._manage_stream_loop_future = self._executor.submit(
self._manage_stream, self._request_queue
)
return self._executor.submit(
self._manage_execution, self._request_queue, project_name, program, job
)

def stop(self) -> None:
"""Closes the open stream and resets all management resources."""
Expand All @@ -168,17 +180,19 @@ def stop(self) -> None:

def _reset(self):
"""Resets the manager state."""
self._request_queue = None
self._manage_stream_loop_future = None
self._response_demux = ResponseDemux()
self._request_queue = self._executor.submit(self._make_request_queue).result()

@property
def _executor(self) -> AsyncioExecutor:
# We must re-use a single Executor due to multi-threading issues in gRPC
# clients: https://github.com/grpc/grpc/issues/25364.
return AsyncioExecutor.instance()

async def _manage_stream(self) -> None:
async def _manage_stream(
self, request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]]
) -> None:
"""The stream coroutine, an asyncio coroutine to manage QuantumRunStream.

This coroutine reads responses from the stream and forwards them to the ResponseDemux, where
Expand All @@ -187,25 +201,32 @@ async def _manage_stream(self) -> None:
When the stream breaks, the stream is reopened, and all execution coroutines are notified.

There is at most a single instance of this coroutine running.

Args:
request_queue: The queue holding requests from the execution coroutine.
"""
self._request_queue = asyncio.Queue()
while True:
try:
# The default gRPC client timeout is used.
response_iterable = await self._grpc_client.quantum_run_stream(
_request_iterator(self._request_queue)
_request_iterator(request_queue)
)
async for response in response_iterable:
self._response_demux.publish(response)
except asyncio.CancelledError:
await request_queue.put(None)
break
except BaseException as e:
# TODO(#5996) Close the request iterator to close the existing stream.
# Note: the message ID counter is not reset upon a new stream.
await request_queue.put(None)
self._response_demux.publish_exception(e) # Raise to all request tasks

async def _manage_execution(
self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
self,
request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]],
project_name: str,
program: quantum.QuantumProgram,
job: quantum.QuantumJob,
) -> Union[quantum.QuantumResult, quantum.QuantumJob]:
"""The execution coroutine, an asyncio coroutine to manage the lifecycle of a job execution.

Expand All @@ -216,28 +237,33 @@ async def _manage_execution(
error by sending another request. The exact request type depends on the error.

There is one execution coroutine per running job submission.

Args:
request_queue: The queue used to send requests to the stream coroutine.
project_name: The full project ID resource path associated with the job.
program: The Quantum Engine program representing the circuit to be executed.
job: The Quantum Engine job to be executed.

Raises:
concurrent.futures.CancelledError: if either the request is cancelled or the stream
coroutine is cancelled.
google.api_core.exceptions.GoogleAPICallError: if the stream breaks with a non-retryable
error.
ValueError: if the response is of a type which is not recognized by this client.
"""
# Construct requests ahead of time to be reused for retries.
create_program_and_job_request = quantum.QuantumRunStreamRequest(
parent=project_name,
create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest(
parent=project_name, quantum_program=program, quantum_job=job
),
)

while self._request_queue is None:
# Wait for the stream coroutine to start.
# Ignoring coverage since this is rarely triggered.
# TODO(#5996) Consider awaiting for the queue to become available, once it is changed
# to be local to the asyncio thread.
await asyncio.sleep(1) # pragma: no cover

current_request = create_program_and_job_request
while True:
try:
current_request.message_id = self._generate_message_id()
response_future = self._response_demux.subscribe(current_request.message_id)
await self._request_queue.put(current_request)
await request_queue.put(current_request)
response = await response_future

# Broken stream
Expand Down Expand Up @@ -325,16 +351,15 @@ def _is_retryable_error(e: google_exceptions.GoogleAPICallError) -> bool:
return any(isinstance(e, exception_type) for exception_type in RETRYABLE_GOOGLE_API_EXCEPTIONS)


# TODO(#5996) Add stop signal to the request iterator.
async def _request_iterator(
request_queue: asyncio.Queue,
request_queue: asyncio.Queue[Optional[quantum.QuantumRunStreamRequest]],
) -> AsyncIterator[quantum.QuantumRunStreamRequest]:
"""The request iterator for Quantum Engine client RPC quantum_run_stream().

Every call to this method generates a new iterator.
"""
while True:
yield await request_queue.get()
while request := await request_queue.get():
yield request


def _to_create_job_request(
Expand Down
127 changes: 116 additions & 11 deletions cirq-google/cirq_google/engine/stream_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,26 @@ def setup(client_constructor):
class FakeQuantumRunStream:
"""A fake Quantum Engine client which supports QuantumRunStream and CancelQuantumJob."""

_REQUEST_STOPPED = 'REQUEST_STOPPED'

def __init__(self):
self.all_stream_requests: List[quantum.QuantumRunStreamRequest] = []
self.all_cancel_requests: List[quantum.CancelQuantumJobRequest] = []
self._executor = AsyncioExecutor.instance()
self._request_buffer = duet.AsyncCollector[quantum.QuantumRunStreamRequest]()
self._request_iterator_stopped = duet.AwaitableFuture()
# asyncio.Queue needs to be initialized inside the asyncio thread because all callers need
# to use the same event loop.
self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]()
self._responses_and_exceptions_future: duet.AwaitableFuture[
asyncio.Queue[Union[quantum.QuantumRunStreamResponse, BaseException]]
] = duet.AwaitableFuture()

async def quantum_run_stream(
self, requests: AsyncIterator[quantum.QuantumRunStreamRequest], **kwargs
) -> Awaitable[AsyncIterable[quantum.QuantumRunStreamResponse]]:
"""Fakes the QuantumRunStream RPC.

Once a request is received, it is appended to `stream_requests`, and the test calling
Once a request is received, it is appended to `all_stream_requests`, and the test calling
`wait_for_requests()` is notified.

The response is sent when a test calls `reply()` with a `QuantumRunStreamResponse`. If a
Expand All @@ -91,25 +96,29 @@ async def quantum_run_stream(

This is called from the asyncio thread.
"""
responses_and_exceptions: asyncio.Queue = asyncio.Queue()
responses_and_exceptions: asyncio.Queue[
Union[quantum.QuantumRunStreamResponse, BaseException]
] = asyncio.Queue()
self._responses_and_exceptions_future.try_set_result(responses_and_exceptions)

async def read_requests():
async for request in requests:
self.all_stream_requests.append(request)
self._request_buffer.add(request)
await responses_and_exceptions.put(FakeQuantumRunStream._REQUEST_STOPPED)
self._request_iterator_stopped.try_set_result(None)

async def response_iterator():
asyncio.create_task(read_requests())
while True:
response_or_exception = await responses_and_exceptions.get()
if isinstance(response_or_exception, quantum.QuantumRunStreamResponse):
yield response_or_exception
else: # isinstance(response_or_exception, BaseException)
self._responses_and_exceptions_future = duet.AwaitableFuture[asyncio.Queue]()
raise response_or_exception
while (
message := await responses_and_exceptions.get()
) != FakeQuantumRunStream._REQUEST_STOPPED:
if isinstance(message, quantum.QuantumRunStreamResponse):
yield message
else: # isinstance(message, BaseException)
self._responses_and_exceptions_future = duet.AwaitableFuture()
raise message

await asyncio.sleep(0)
return response_iterator()

async def cancel_quantum_job(self, request: quantum.CancelQuantumJobRequest) -> None:
Expand Down Expand Up @@ -158,6 +167,14 @@ async def send():

await self._executor.submit(send)

async def wait_for_request_iterator_stop(self):
"""Wait for the request iterator to stop.

This must be called from a duet thread.
"""
await self._request_iterator_stopped
self._request_iterator_stopped = duet.AwaitableFuture()


class TestResponseDemux:
@pytest.fixture
Expand Down Expand Up @@ -704,3 +721,91 @@ def test_get_retry_request_or_raise_expects_stream_error(
create_quantum_program_and_job_request,
create_quantum_job_request,
)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_broken_stream_stops_request_iterator(self, client_constructor):
expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
fake_client, manager = setup(client_constructor)

async def test():
async with duet.timeout_scope(5):
actual_result_future = manager.submit(
REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0
)
await fake_client.wait_for_requests()
await fake_client.reply(
quantum.QuantumRunStreamResponse(
message_id=fake_client.all_stream_requests[0].message_id,
result=expected_result,
)
)
await actual_result_future
await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable'))
await fake_client.wait_for_request_iterator_stop()
manager.stop()

duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_stop_stops_request_iterator(self, client_constructor):
expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
fake_client, manager = setup(client_constructor)

async def test():
async with duet.timeout_scope(5):
actual_result_future = manager.submit(
REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0
)
await fake_client.wait_for_requests()
await fake_client.reply(
quantum.QuantumRunStreamResponse(
message_id=fake_client.all_stream_requests[0].message_id,
result=expected_result,
)
)
await actual_result_future
manager.stop()
await fake_client.wait_for_request_iterator_stop()

duet.run(test)

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_submit_after_stream_breakage(self, client_constructor):
expected_result0 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
expected_result1 = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job1')
fake_client, manager = setup(client_constructor)

async def test():
async with duet.timeout_scope(5):
actual_result0_future = manager.submit(
REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0
)
await fake_client.wait_for_requests()
await fake_client.reply(
quantum.QuantumRunStreamResponse(
message_id=fake_client.all_stream_requests[0].message_id,
result=expected_result0,
)
)
actual_result0 = await actual_result0_future
await fake_client.reply(google_exceptions.ServiceUnavailable('service unavailable'))
actual_result1_future = manager.submit(
REQUEST_PROJECT_NAME, REQUEST_PROGRAM, REQUEST_JOB0
)
await fake_client.wait_for_requests()
await fake_client.reply(
quantum.QuantumRunStreamResponse(
message_id=fake_client.all_stream_requests[1].message_id,
result=expected_result1,
)
)
actual_result1 = await actual_result1_future
manager.stop()

assert len(fake_client.all_stream_requests) == 2
assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[0]
assert 'create_quantum_program_and_job' in fake_client.all_stream_requests[1]
assert actual_result0 == expected_result0
assert actual_result1 == expected_result1

duet.run(test)
Loading