From 4dc36d54b5adb69b5c5abac1edd1e057e1dbf2e4 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 4 Oct 2023 23:17:31 +0000 Subject: [PATCH] Integrate StreamManager with run_sweep() (#6285) * Integrate StreamManager with run_sweep() * Added a feature flag defaulting to True. * Added logic to temporarily rewrite `processor_ids` to `processor_id` prior to full deprecation of `processor_ids` in `Engine.run_sweep()`. * [WIP] Addressed maffoo's comments * Addressed wcourtney's comments * Modified engine_processor_test to take into account the enable_streaming feature flag * Address wcourtney's comments * Rename job_response_future to job_result_future --- cirq-google/cirq_google/engine/engine.py | 48 ++- .../cirq_google/engine/engine_client.py | 96 +++++ .../cirq_google/engine/engine_client_test.py | 372 +++++++++++++++--- cirq-google/cirq_google/engine/engine_job.py | 31 +- .../cirq_google/engine/engine_job_test.py | 46 +++ .../engine/engine_processor_test.py | 68 +++- cirq-google/cirq_google/engine/engine_test.py | 113 ++---- 7 files changed, 646 insertions(+), 128 deletions(-) diff --git a/cirq-google/cirq_google/engine/engine.py b/cirq-google/cirq_google/engine/engine.py index 38fbb6d6a4d..71a1657eacd 100644 --- a/cirq-google/cirq_google/engine/engine.py +++ b/cirq-google/cirq_google/engine/engine.py @@ -88,6 +88,8 @@ def __init__( client: 'Optional[engine_client.EngineClient]' = None, timeout: Optional[int] = None, serializer: Serializer = CIRCUIT_SERIALIZER, + # TODO(#5996) Remove enable_streaming once the feature is stable. + enable_streaming: bool = True, ) -> None: """Context and client for using Quantum Engine. @@ -103,6 +105,9 @@ def __init__( timeout: Timeout for polling for results, in seconds. Default is to never timeout. serializer: Used to serialize circuits when running jobs. + enable_streaming: Feature gate for making Quantum Engine requests using the stream RPC. + If True, the Quantum Engine streaming RPC is used for creating jobs + and getting results. Otherwise, unary RPCs are used. Raises: ValueError: If either `service_args` and `verbose` were supplied @@ -115,6 +120,7 @@ def __init__( if self.proto_version == ProtoVersion.V1: raise ValueError('ProtoVersion V1 no longer supported') self.serializer = serializer + self.enable_streaming = enable_streaming if not client: client = engine_client.EngineClient(service_args=service_args, verbose=verbose) @@ -306,7 +312,7 @@ async def run_sweep_async( run_name: str = "", device_config_name: str = "", ) -> engine_job.EngineJob: - """Runs the supplied Circuit via Quantum Engine.Creates + """Runs the supplied Circuit via Quantum Engine. In contrast to run, this runs across multiple parameter sweeps, and does not block until a result is returned. @@ -355,6 +361,44 @@ async def run_sweep_async( ValueError: If either `run_name` and `device_config_name` are set but `processor_id` is empty. """ + + if self.context.enable_streaming: + # This logic is temporary prior to deprecating the processor_ids parameter. + # TODO(#6271) Remove after deprecating processor_ids elsewhere prior to v1.4. + if processor_ids: + if len(processor_ids) > 1: + raise ValueError("The use of multiple processors is no longer supported.") + if len(processor_ids) == 1 and not processor_id: + processor_id = processor_ids[0] + + if not program_id: + program_id = _make_random_id('prog-') + if not job_id: + job_id = _make_random_id('job-') + run_context = self.context._serialize_run_context(params, repetitions) + + job_result_future = self.context.client.run_job_over_stream( + project_id=self.project_id, + program_id=str(program_id), + program_description=program_description, + program_labels=program_labels, + code=self.context._serialize_program(program), + job_id=str(job_id), + run_context=run_context, + job_description=job_description, + job_labels=job_labels, + processor_id=processor_id, + run_name=run_name, + device_config_name=device_config_name, + ) + return engine_job.EngineJob( + self.project_id, + str(program_id), + str(job_id), + self.context, + job_result_future=job_result_future, + ) + engine_program = await self.create_program_async( program, program_id, description=program_description, labels=program_labels ) @@ -372,6 +416,7 @@ async def run_sweep_async( run_sweep = duet.sync(run_sweep_async) + # TODO(#5996) Migrate to stream client # TODO(#6271): Deprecate and remove processor_ids before v1.4 async def run_batch_async( self, @@ -475,6 +520,7 @@ async def run_batch_async( run_batch = duet.sync(run_batch_async) + # TODO(#5996) Migrate to stream client async def run_calibration_async( self, layers: List['cirq_google.CalibrationLayer'], diff --git a/cirq-google/cirq_google/engine/engine_client.py b/cirq-google/cirq_google/engine/engine_client.py index 017fa8a2041..9d2fb8bae92 100644 --- a/cirq-google/cirq_google/engine/engine_client.py +++ b/cirq-google/cirq_google/engine/engine_client.py @@ -39,6 +39,7 @@ from cirq._compat import deprecated_parameter from cirq_google.cloud import quantum from cirq_google.engine.asyncio_executor import AsyncioExecutor +from cirq_google.engine import stream_manager _M = TypeVar('_M', bound=proto.Message) _R = TypeVar('_R') @@ -106,6 +107,10 @@ async def make_client(): return self._executor.submit(make_client).result() + @cached_property + def _stream_manager(self) -> stream_manager.StreamManager: + return stream_manager.StreamManager(self.grpc_client) + async def _send_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R: """Sends a request by invoking an asyncio callable.""" return await self._run_retry_async(func, request) @@ -736,6 +741,97 @@ async def get_job_results_async( get_job_results = duet.sync(get_job_results_async) + def run_job_over_stream( + self, + *, + project_id: str, + program_id: str, + code: any_pb2.Any, + run_context: any_pb2.Any, + program_description: Optional[str] = None, + program_labels: Optional[Dict[str, str]] = None, + job_id: str, + priority: Optional[int] = None, + job_description: Optional[str] = None, + job_labels: Optional[Dict[str, str]] = None, + processor_id: str = "", + run_name: str = "", + device_config_name: str = "", + ) -> duet.AwaitableFuture[Union[quantum.QuantumResult, quantum.QuantumJob]]: + """Runs a job with the given program and job information over a stream. + + Sends the request over the Quantum Engine QuantumRunStream bidirectional stream, and returns + a future for the stream response. The future will be completed with a `QuantumResult` if + the job is successful; otherwise, it will be completed with a QuantumJob. + + Args: + project_id: A project_id of the parent Google Cloud Project. + program_id: Unique ID of the program within the parent project. + code: Properly serialized program code. + run_context: Properly serialized run context. + program_description: An optional description to set on the program. + program_labels: Optional set of labels to set on the program. + job_id: Unique ID of the job within the parent program. + priority: Optional priority to run at, 0-1000. + job_description: Optional description to set on the job. + job_labels: Optional set of labels to set on the job. + processor_id: Processor id for running the program. If not set, + `processor_ids` will be used. + run_name: A unique identifier representing an automation run for the + specified processor. An Automation Run contains a collection of + device configurations for a processor. If specified, `processor_id` + is required to be set. + device_config_name: An identifier used to select the processor configuration + utilized to run the job. A configuration identifies the set of + available qubits, couplers, and supported gates in the processor. + If specified, `processor_id` is required to be set. + + Returns: + A future for the job result, or the job if the job has failed. + + Raises: + ValueError: If the priority is not between 0 and 1000. + ValueError: If `processor_id` is not set. + ValueError: If only one of `run_name` and `device_config_name` are specified. + """ + # Check program to run and program parameters. + if priority and not 0 <= priority < 1000: + raise ValueError('priority must be between 0 and 1000') + if not processor_id: + raise ValueError('Must specify a processor id when creating a job.') + if bool(run_name) ^ bool(device_config_name): + raise ValueError('Cannot specify only one of `run_name` and `device_config_name`') + + project_name = _project_name(project_id) + + program_name = _program_name_from_ids(project_id, program_id) + program = quantum.QuantumProgram(name=program_name, code=code) + if program_description: + program.description = program_description + if program_labels: + program.labels.update(program_labels) + + job = quantum.QuantumJob( + name=_job_name_from_ids(project_id, program_id, job_id), + scheduling_config=quantum.SchedulingConfig( + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor=_processor_name_from_ids(project_id, processor_id), + device_config_key=quantum.DeviceConfigKey( + run_name=run_name, config_alias=device_config_name + ), + ) + ), + run_context=run_context, + ) + if priority: + job.scheduling_config.priority = priority + if job_description: + job.description = job_description + if job_labels: + job.labels.update(job_labels) + + return self._stream_manager.submit(project_name, program, job) + async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]: """Returns a list of Processors that the user has visibility to in the current Engine project. The names of these processors are used to diff --git a/cirq-google/cirq_google/engine/engine_client_test.py b/cirq-google/cirq_google/engine/engine_client_test.py index b1fcec6e6ba..82273d9f29d 100644 --- a/cirq-google/cirq_google/engine/engine_client_test.py +++ b/cirq-google/cirq_google/engine/engine_client_test.py @@ -25,18 +25,25 @@ from google.protobuf.timestamp_pb2 import Timestamp from cirq_google.engine.engine_client import EngineClient, EngineException +import cirq_google.engine.stream_manager as engine_stream_manager from cirq_google.cloud import quantum -def setup_mock_(client_constructor): +def _setup_client_mock(client_constructor): grpc_client = mock.AsyncMock() client_constructor.return_value = grpc_client return grpc_client +def _setup_stream_manager_mock(manager_constructor): + stream_manager = mock.MagicMock() + manager_constructor.return_value = stream_manager + return stream_manager + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_create_program(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumProgram(name='projects/proj/programs/prog') grpc_client.create_quantum_program.return_value = result @@ -96,7 +103,7 @@ def test_create_program(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_program(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumProgram(name='projects/proj/programs/prog') grpc_client.get_quantum_program.return_value = result @@ -115,7 +122,7 @@ def test_get_program(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_program(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) results = [ quantum.QuantumProgram(name='projects/proj/programs/prog1'), @@ -163,7 +170,7 @@ def test_list_program(client_constructor): def test_list_program_filters( client_constructor, expected_filter, created_before, created_after, labels ): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) client = EngineClient() client.list_programs( project_id='proj', @@ -182,7 +189,7 @@ def test_list_program_filters_invalid_type(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_program_description(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumProgram(name='projects/proj/programs/prog') grpc_client.update_quantum_program.return_value = result @@ -211,7 +218,7 @@ def test_set_program_description(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_program_labels(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) grpc_client.get_quantum_program.return_value = quantum.QuantumProgram( labels={'color': 'red', 'weather': 'sun', 'run': '1'}, label_fingerprint='hash' @@ -246,7 +253,7 @@ def test_set_program_labels(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_add_program_labels(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) existing = quantum.QuantumProgram( labels={'color': 'red', 'weather': 'sun', 'run': '1'}, label_fingerprint='hash' @@ -288,7 +295,7 @@ def test_add_program_labels(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_remove_program_labels(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) existing = quantum.QuantumProgram( labels={'color': 'red', 'weather': 'sun', 'run': '1'}, label_fingerprint='hash' @@ -328,7 +335,7 @@ def test_remove_program_labels(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_delete_program(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) client = EngineClient() assert not client.delete_program('proj', 'prog') @@ -344,8 +351,8 @@ def test_delete_program(client_constructor): @mock.patch.dict(os.environ, clear='CIRQ_TESTING') @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) -def test_create_job_with_legacy_processor_ids(client_constructor): - grpc_client = setup_mock_(client_constructor) +def test_create_job(client_constructor): + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') grpc_client.create_quantum_job.return_value = result @@ -513,7 +520,7 @@ def test_create_job_with_legacy_processor_ids(client_constructor): def test_create_job_with_invalid_processor_and_device_config_arguments_throws( client_constructor, processor_ids, processor_id, run_name, device_config_name, error_message ): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') grpc_client.create_quantum_job.return_value = result client = EngineClient() @@ -539,7 +546,7 @@ def test_create_job_with_invalid_processor_and_device_config_arguments_throws( def test_create_job_with_run_name_and_device_config_name( client_constructor, processor_ids, processor_id, run_name, device_config_name ): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') grpc_client.create_quantum_job.return_value = result run_context = any_pb2.Any() @@ -576,9 +583,286 @@ def test_create_job_with_run_name_and_device_config_name( ) +@pytest.mark.parametrize( + 'run_job_kwargs, expected_submit_args', + [ + ( + { + 'project_id': 'proj', + 'program_id': 'prog', + 'code': any_pb2.Any(), + 'job_id': 'job0', + 'processor_id': 'processor0', + 'run_context': any_pb2.Any(), + 'program_description': 'A program', + 'program_labels': {'hello': 'world'}, + 'priority': 10, + 'job_description': 'A job', + 'job_labels': {'hello': 'world'}, + }, + [ + 'projects/proj', + quantum.QuantumProgram( + name='projects/proj/programs/prog', + code=any_pb2.Any(), + description='A program', + labels={'hello': 'world'}, + ), + quantum.QuantumJob( + name='projects/proj/programs/prog/jobs/job0', + run_context=any_pb2.Any(), + scheduling_config=quantum.SchedulingConfig( + priority=10, + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor='projects/proj/processors/processor0', + device_config_key=quantum.DeviceConfigKey(), + ), + ), + description='A job', + labels={'hello': 'world'}, + ), + ], + ), + # Missing program labels + ( + { + 'project_id': 'proj', + 'program_id': 'prog', + 'code': any_pb2.Any(), + 'job_id': 'job0', + 'processor_id': 'processor0', + 'run_context': any_pb2.Any(), + 'program_description': 'A program', + 'priority': 10, + 'job_description': 'A job', + 'job_labels': {'hello': 'world'}, + }, + [ + 'projects/proj', + quantum.QuantumProgram( + name='projects/proj/programs/prog', code=any_pb2.Any(), description='A program' + ), + quantum.QuantumJob( + name='projects/proj/programs/prog/jobs/job0', + run_context=any_pb2.Any(), + scheduling_config=quantum.SchedulingConfig( + priority=10, + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor='projects/proj/processors/processor0', + device_config_key=quantum.DeviceConfigKey(), + ), + ), + description='A job', + labels={'hello': 'world'}, + ), + ], + ), + # Missing program description and labels + ( + { + 'project_id': 'proj', + 'program_id': 'prog', + 'code': any_pb2.Any(), + 'job_id': 'job0', + 'processor_id': 'processor0', + 'run_context': any_pb2.Any(), + 'priority': 10, + 'job_description': 'A job', + 'job_labels': {'hello': 'world'}, + }, + [ + 'projects/proj', + quantum.QuantumProgram(name='projects/proj/programs/prog', code=any_pb2.Any()), + quantum.QuantumJob( + name='projects/proj/programs/prog/jobs/job0', + run_context=any_pb2.Any(), + scheduling_config=quantum.SchedulingConfig( + priority=10, + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor='projects/proj/processors/processor0', + device_config_key=quantum.DeviceConfigKey(), + ), + ), + description='A job', + labels={'hello': 'world'}, + ), + ], + ), + # Missing job labels + ( + { + 'project_id': 'proj', + 'program_id': 'prog', + 'code': any_pb2.Any(), + 'job_id': 'job0', + 'processor_id': 'processor0', + 'run_context': any_pb2.Any(), + 'program_description': 'A program', + 'program_labels': {'hello': 'world'}, + 'priority': 10, + 'job_description': 'A job', + }, + [ + 'projects/proj', + quantum.QuantumProgram( + name='projects/proj/programs/prog', + code=any_pb2.Any(), + description='A program', + labels={'hello': 'world'}, + ), + quantum.QuantumJob( + name='projects/proj/programs/prog/jobs/job0', + run_context=any_pb2.Any(), + scheduling_config=quantum.SchedulingConfig( + priority=10, + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor='projects/proj/processors/processor0', + device_config_key=quantum.DeviceConfigKey(), + ), + ), + description='A job', + ), + ], + ), + # Missing job description and labels + ( + { + 'project_id': 'proj', + 'program_id': 'prog', + 'code': any_pb2.Any(), + 'job_id': 'job0', + 'processor_id': 'processor0', + 'run_context': any_pb2.Any(), + 'program_description': 'A program', + 'program_labels': {'hello': 'world'}, + 'priority': 10, + }, + [ + 'projects/proj', + quantum.QuantumProgram( + name='projects/proj/programs/prog', + code=any_pb2.Any(), + description='A program', + labels={'hello': 'world'}, + ), + quantum.QuantumJob( + name='projects/proj/programs/prog/jobs/job0', + run_context=any_pb2.Any(), + scheduling_config=quantum.SchedulingConfig( + priority=10, + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor='projects/proj/processors/processor0', + device_config_key=quantum.DeviceConfigKey(), + ), + ), + ), + ], + ), + # Missing job priority, description, and labels + ( + { + 'project_id': 'proj', + 'program_id': 'prog', + 'code': any_pb2.Any(), + 'job_id': 'job0', + 'processor_id': 'processor0', + 'run_context': any_pb2.Any(), + 'program_description': 'A program', + 'program_labels': {'hello': 'world'}, + }, + [ + 'projects/proj', + quantum.QuantumProgram( + name='projects/proj/programs/prog', + code=any_pb2.Any(), + description='A program', + labels={'hello': 'world'}, + ), + quantum.QuantumJob( + name='projects/proj/programs/prog/jobs/job0', + run_context=any_pb2.Any(), + scheduling_config=quantum.SchedulingConfig( + processor_selector=quantum.SchedulingConfig.ProcessorSelector( + processor='projects/proj/processors/processor0', + device_config_key=quantum.DeviceConfigKey(), + ) + ), + ), + ], + ), + ], +) +@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) +@mock.patch.object(engine_stream_manager, 'StreamManager', autospec=True) +def test_run_job_over_stream( + manager_constructor, client_constructor, run_job_kwargs, expected_submit_args +): + _setup_client_mock(client_constructor) + stream_manager = _setup_stream_manager_mock(manager_constructor) + + result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') + expected_future = duet.AwaitableFuture() + expected_future.try_set_result(result) + stream_manager.submit.return_value = expected_future + client = EngineClient() + + actual_future = client.run_job_over_stream(**run_job_kwargs) + + assert actual_future == expected_future + stream_manager.submit.assert_called_with(*expected_submit_args) + + +def test_run_job_over_stream_with_priority_out_of_bound_raises(): + client = EngineClient() + + with pytest.raises(ValueError): + client.run_job_over_stream( + project_id='proj', + program_id='prog', + code=any_pb2.Any(), + job_id='job0', + processor_id='processor0', + run_context=any_pb2.Any(), + priority=9001, + ) + + +def test_run_job_over_stream_processor_unset_raises(): + client = EngineClient() + + with pytest.raises(ValueError, match='Must specify a processor id'): + client.run_job_over_stream( + project_id='proj', + program_id='prog', + code=any_pb2.Any(), + job_id='job0', + processor_id='', + run_context=any_pb2.Any(), + ) + + +@pytest.mark.parametrize('run_name, device_config_name', [('run1', ''), ('', 'device_config1')]) +def test_run_job_over_stream_invalid_device_config_raises(run_name, device_config_name): + client = EngineClient() + + with pytest.raises( + ValueError, match='Cannot specify only one of `run_name` and `device_config_name`' + ): + client.run_job_over_stream( + project_id='proj', + program_id='prog', + code=any_pb2.Any(), + job_id='job0', + processor_id='mysim', + run_context=any_pb2.Any(), + run_name=run_name, + device_config_name=device_config_name, + ) + + @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_job(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') grpc_client.get_quantum_job.return_value = result @@ -601,7 +885,7 @@ def test_get_job(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_job_description(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0') grpc_client.update_quantum_job.return_value = result @@ -630,7 +914,7 @@ def test_set_job_description(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_set_job_labels(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) grpc_client.get_quantum_job.return_value = quantum.QuantumJob( labels={'color': 'red', 'weather': 'sun', 'run': '1'}, label_fingerprint='hash' @@ -667,7 +951,7 @@ def test_set_job_labels(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_add_job_labels(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) existing = quantum.QuantumJob( labels={'color': 'red', 'weather': 'sun', 'run': '1'}, label_fingerprint='hash' @@ -711,7 +995,7 @@ def test_add_job_labels(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_remove_job_labels(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) existing = quantum.QuantumJob( labels={'color': 'red', 'weather': 'sun', 'run': '1'}, label_fingerprint='hash' @@ -751,7 +1035,7 @@ def test_remove_job_labels(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_delete_job(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) client = EngineClient() assert not client.delete_job('proj', 'prog', 'job0') @@ -762,7 +1046,7 @@ def test_delete_job(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_cancel_job(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) client = EngineClient() assert not client.cancel_job('proj', 'prog', 'job0') @@ -773,7 +1057,7 @@ def test_cancel_job(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_job_results(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0') grpc_client.get_quantum_result.return_value = result @@ -787,7 +1071,7 @@ def test_job_results(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_jobs(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) results = [ quantum.QuantumJob(name='projects/proj/programs/prog1/jobs/job1'), @@ -900,7 +1184,7 @@ def test_list_jobs_filters( executed_processor_ids, scheduled_processor_ids, ): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) client = EngineClient() client.list_jobs( project_id='proj', @@ -929,7 +1213,7 @@ async def __aiter__(self): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_processors(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) results = [ quantum.QuantumProcessor(name='projects/proj/processor/processor0'), @@ -946,7 +1230,7 @@ def test_list_processors(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_processor(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumProcessor(name='projects/proj/processors/processor0') grpc_client.get_quantum_processor.return_value = result @@ -960,7 +1244,7 @@ def test_get_processor(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_calibrations(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) results = [ quantum.QuantumCalibration(name='projects/proj/processor/processor0/calibrations/123456'), @@ -977,7 +1261,7 @@ def test_list_calibrations(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_calibration(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumCalibration( name='projects/proj/processors/processor0/calibrations/123456' @@ -995,7 +1279,7 @@ def test_get_calibration(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_current_calibration(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) result = quantum.QuantumCalibration( name='projects/proj/processors/processor0/calibrations/123456' @@ -1013,7 +1297,7 @@ def test_get_current_calibration(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_current_calibration_does_not_exist(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) grpc_client.get_quantum_calibration.side_effect = exceptions.NotFound('not found') @@ -1028,7 +1312,7 @@ def test_get_current_calibration_does_not_exist(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_current_calibration_error(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) grpc_client.get_quantum_calibration.side_effect = exceptions.BadRequest('boom') @@ -1039,7 +1323,7 @@ def test_get_current_calibration_error(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_api_doesnt_retry_not_found_errors(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) grpc_client.get_quantum_program.side_effect = exceptions.NotFound('not found') client = EngineClient() @@ -1050,7 +1334,7 @@ def test_api_doesnt_retry_not_found_errors(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_api_retry_5xx_errors(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) grpc_client.get_quantum_program.side_effect = exceptions.ServiceUnavailable('internal error') client = EngineClient(max_retry_delay_seconds=0.3) @@ -1062,7 +1346,7 @@ def test_api_retry_5xx_errors(client_constructor): @mock.patch('duet.sleep', return_value=duet.completed_future(None)) @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_api_retry_times(client_constructor, mock_sleep): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) grpc_client.get_quantum_program.side_effect = exceptions.ServiceUnavailable('internal error') client = EngineClient(max_retry_delay_seconds=0.3) @@ -1076,7 +1360,7 @@ def test_api_retry_times(client_constructor, mock_sleep): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_create_reservation(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) start = datetime.datetime.fromtimestamp(1000000000) end = datetime.datetime.fromtimestamp(1000003600) users = ['jeff@google.com'] @@ -1102,7 +1386,7 @@ def test_create_reservation(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_cancel_reservation(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' result = quantum.QuantumReservation( name=name, @@ -1121,7 +1405,7 @@ def test_cancel_reservation(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_delete_reservation(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' result = quantum.QuantumReservation( name=name, @@ -1140,7 +1424,7 @@ def test_delete_reservation(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_reservation(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' result = quantum.QuantumReservation( name=name, @@ -1159,7 +1443,7 @@ def test_get_reservation(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_reservation_not_found(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' grpc_client.get_quantum_reservation.side_effect = exceptions.NotFound('not found') @@ -1172,7 +1456,7 @@ def test_get_reservation_not_found(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_get_reservation_exception(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) grpc_client.get_quantum_reservation.side_effect = exceptions.BadRequest('boom') client = EngineClient() @@ -1182,7 +1466,7 @@ def test_get_reservation_exception(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_reservation(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' results = [ quantum.QuantumReservation( @@ -1206,7 +1490,7 @@ def test_list_reservation(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_update_reservation(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' result = quantum.QuantumReservation( name=name, @@ -1239,7 +1523,7 @@ def test_update_reservation(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_update_reservation_remove_all_users(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) name = 'projects/proj/processors/processor0/reservations/papar-party-44' result = quantum.QuantumReservation(name=name, whitelisted_users=[]) grpc_client.update_quantum_reservation.return_value = result @@ -1260,7 +1544,7 @@ def test_update_reservation_remove_all_users(client_constructor): @mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True) def test_list_time_slots(client_constructor): - grpc_client = setup_mock_(client_constructor) + grpc_client = _setup_client_mock(client_constructor) results = [ quantum.QuantumTimeSlot( processor_name='potofgold', diff --git a/cirq-google/cirq_google/engine/engine_job.py b/cirq-google/cirq_google/engine/engine_job.py index 5fbfca657f2..7d3db35dd88 100644 --- a/cirq-google/cirq_google/engine/engine_job.py +++ b/cirq-google/cirq_google/engine/engine_job.py @@ -14,7 +14,7 @@ """A helper for jobs that have been created on the Quantum Engine.""" import datetime -from typing import Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING +from typing import Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import duet from google.protobuf import any_pb2 @@ -69,6 +69,9 @@ def __init__( context: 'engine_base.EngineContext', _job: Optional[quantum.QuantumJob] = None, result_type: ResultType = ResultType.Program, + job_result_future: Optional[ + duet.AwaitableFuture[Union[quantum.QuantumResult, quantum.QuantumJob]] + ] = None, ) -> None: """A job submitted to the engine. @@ -79,7 +82,10 @@ def __init__( context: Engine configuration and context to use. _job: The optional current job state. result_type: What type of results are expected, such as - batched results or the result of a focused calibration. + batched results or the result of a focused calibration. + job_result_future: A future to be completed when the job result is available. + If set, EngineJob will await this future when a caller asks for the job result. If + the future is completed with a `QuantumJob`, it is assumed that the job has failed. """ self.project_id = project_id self.program_id = program_id @@ -90,6 +96,7 @@ def __init__( self._calibration_results: Optional[Sequence[CalibrationResult]] = None self._batched_results: Optional[Sequence[Sequence[EngineResult]]] = None self.result_type = result_type + self._job_result_future = job_result_future def id(self) -> str: """Returns the job id.""" @@ -279,7 +286,8 @@ async def results_async(self) -> Sequence[EngineResult]: import cirq_google.engine.engine as engine_base if self._results is None: - result = await self._await_result_async() + result_response = await self._await_result_async() + result = result_response.result result_type = result.type_url[len(engine_base.TYPE_PREFIX) :] if ( result_type == 'cirq.google.api.v1.Result' @@ -302,6 +310,18 @@ async def results_async(self) -> Sequence[EngineResult]: return self._results async def _await_result_async(self) -> quantum.QuantumResult: + if self._job_result_future is not None: + response = await self._job_result_future + if isinstance(response, quantum.QuantumResult): + return response + elif isinstance(response, quantum.QuantumJob): + self._job = response + _raise_on_failure(response) + else: + raise ValueError( + 'Internal error: The job response type is not recognized.' + ) # pragma: no cover + async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type] while True: job = await self._refresh_job_async() @@ -312,7 +332,7 @@ async def _await_result_async(self) -> quantum.QuantumResult: response = await self.context.client.get_job_results_async( self.project_id, self.program_id, self.job_id ) - return response.result + return response async def calibration_results_async(self) -> Sequence[CalibrationResult]: """Returns the results of a run_calibration() call. @@ -323,7 +343,8 @@ async def calibration_results_async(self) -> Sequence[CalibrationResult]: import cirq_google.engine.engine as engine_base if self._calibration_results is None: - result = await self._await_result_async() + result_response = await self._await_result_async() + result = result_response.result result_type = result.type_url[len(engine_base.TYPE_PREFIX) :] if result_type != 'cirq.google.api.v2.FocusedCalibrationResult': raise ValueError(f'Did not find calibration results, instead found: {result_type}') diff --git a/cirq-google/cirq_google/engine/engine_job_test.py b/cirq-google/cirq_google/engine/engine_job_test.py index 82c07398300..1786d2651be 100644 --- a/cirq-google/cirq_google/engine/engine_job_test.py +++ b/cirq-google/cirq_google/engine/engine_job_test.py @@ -16,6 +16,7 @@ from unittest import mock import pytest +import duet from google.protobuf import any_pb2, timestamp_pb2 from google.protobuf.text_format import Merge @@ -550,6 +551,51 @@ def test_results_getitem(get_job_results): _ = job[2] +def test_receives_results_via_stream_returns_correct_results(): + qjob = quantum.QuantumJob( + execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS), + update_time=UPDATE_TIME, + ) + result_future = duet.completed_future(RESULTS) + + job = cg.EngineJob( + 'a', 'b', 'steve', EngineContext(), _job=qjob, job_result_future=result_future + ) + data = job.results() + + assert len(data) == 2 + assert str(data[0]) == 'q=0110' + assert str(data[1]) == 'q=1010' + + +def test_receives_job_via_stream_raises_and_updates_underlying_job(): + expected_error_code = quantum.ExecutionStatus.Failure.Code.SYSTEM_ERROR + expected_error_message = 'system error' + qjob = quantum.QuantumJob( + execution_status=quantum.ExecutionStatus( + state=quantum.ExecutionStatus.State.SUCCESS, + failure=quantum.ExecutionStatus.Failure( + error_code=expected_error_code, error_message=expected_error_message + ), + ), + update_time=UPDATE_TIME, + ) + result_future = duet.completed_future(qjob) + + job = cg.EngineJob( + 'a', 'b', 'steve', EngineContext(), _job=qjob, job_result_future=result_future + ) + qjob.execution_status.state = quantum.ExecutionStatus.State.FAILURE + + with pytest.raises(RuntimeError): + job.results() + actual_error_code, actual_error_message = job.failure() + + # Checks that the underlying job has been updated by checking failure information. + assert actual_error_code == expected_error_code.name + assert actual_error_message == expected_error_message + + @mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async') def test_batched_results(get_job_results): qjob = quantum.QuantumJob( diff --git a/cirq-google/cirq_google/engine/engine_processor_test.py b/cirq-google/cirq_google/engine/engine_processor_test.py index 61ac367704f..0273a70c0c7 100644 --- a/cirq-google/cirq_google/engine/engine_processor_test.py +++ b/cirq-google/cirq_google/engine/engine_processor_test.py @@ -15,6 +15,7 @@ from unittest import mock import datetime +import duet import pytest import freezegun import numpy as np @@ -797,7 +798,7 @@ def test_list_reservations_time_filter_behavior(list_reservations): @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) -def test_run_sweep_params(client): +def test_run_sweep_params_with_unary_rpcs(client): client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), @@ -815,7 +816,7 @@ def test_run_sweep_params(client): result=util.pack_any(_RESULTS_V2) ) - processor = cg.EngineProcessor('a', 'p', EngineContext()) + processor = cg.EngineProcessor('a', 'p', EngineContext(enable_streaming=False)) job = processor.run_sweep( program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})] ) @@ -844,6 +845,42 @@ def test_run_sweep_params(client): client().get_job_results_async.assert_called_once() +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) +def test_run_sweep_params_with_stream_rpcs(client): + client().get_job_async.return_value = quantum.QuantumJob( + execution_status={'state': 'SUCCESS'}, update_time=_to_timestamp('2019-07-09T23:39:59Z') + ) + expected_result = quantum.QuantumResult(result=util.pack_any(_RESULTS_V2)) + stream_future = duet.AwaitableFuture() + stream_future.try_set_result(expected_result) + client().run_job_over_stream.return_value = stream_future + + processor = cg.EngineProcessor('a', 'p', EngineContext(enable_streaming=True)) + job = processor.run_sweep( + program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})] + ) + results = job.results() + assert len(results) == 2 + for i, v in enumerate([1, 2]): + assert results[i].repetitions == 1 + assert results[i].params.param_dict == {'a': v} + assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} + for result in results: + assert result.job_id == job.id() + assert result.job_finished_time is not None + assert results == cirq.read_json(json_text=cirq.to_json(results)) + + client().run_job_over_stream.assert_called_once() + + run_context = v2.run_context_pb2.RunContext() + client().run_job_over_stream.call_args[1]['run_context'].Unpack(run_context) + sweeps = run_context.parameter_sweeps + assert len(sweeps) == 2 + for i, v in enumerate([1.0, 2.0]): + assert sweeps[i].repetitions == 1 + assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] + + @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_batch(client): client().create_program_async.return_value = ( @@ -941,7 +978,7 @@ def test_run_calibration(client): @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) -def test_sampler(client): +def test_sampler_with_unary_rpcs(client): client().create_program_async.return_value = ( 'prog', quantum.QuantumProgram(name='projects/proj/programs/prog'), @@ -958,7 +995,7 @@ def test_sampler(client): client().get_job_results_async.return_value = quantum.QuantumResult( result=util.pack_any(_RESULTS_V2) ) - processor = cg.EngineProcessor('proj', 'mysim', EngineContext()) + processor = cg.EngineProcessor('proj', 'mysim', EngineContext(enable_streaming=False)) sampler = processor.get_sampler() results = sampler.run_sweep( program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})] @@ -971,6 +1008,29 @@ def test_sampler(client): assert client().create_program_async.call_args[0][0] == 'proj' +@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) +def test_sampler_with_stream_rpcs(client): + client().get_job_async.return_value = quantum.QuantumJob( + execution_status={'state': 'SUCCESS'}, update_time=_to_timestamp('2019-07-09T23:39:59Z') + ) + expected_result = quantum.QuantumResult(result=util.pack_any(_RESULTS_V2)) + stream_future = duet.AwaitableFuture() + stream_future.try_set_result(expected_result) + client().run_job_over_stream.return_value = stream_future + + processor = cg.EngineProcessor('proj', 'mysim', EngineContext(enable_streaming=True)) + sampler = processor.get_sampler() + results = sampler.run_sweep( + program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})] + ) + assert len(results) == 2 + for i, v in enumerate([1, 2]): + assert results[i].repetitions == 1 + assert results[i].params.param_dict == {'a': v} + assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} + assert client().run_job_over_stream.call_args[1]['project_id'] == 'proj' + + def test_str(): processor = cg.EngineProcessor('a', 'p', EngineContext()) assert str(processor) == 'EngineProcessor(project_id=\'a\', processor_id=\'p\')' diff --git a/cirq-google/cirq_google/engine/engine_test.py b/cirq-google/cirq_google/engine/engine_test.py index 64c65ff8324..f34c3e0851a 100644 --- a/cirq-google/cirq_google/engine/engine_test.py +++ b/cirq-google/cirq_google/engine/engine_test.py @@ -19,6 +19,7 @@ import numpy as np import pytest +import duet from google.protobuf import any_pb2, timestamp_pb2 from google.protobuf.text_format import Merge @@ -348,6 +349,9 @@ def setup_run_circuit_with_result_(client, result): execution_status={'state': 'SUCCESS'}, update_time=_DT ) client().get_job_results_async.return_value = quantum.QuantumResult(result=result) + stream_future = duet.AwaitableFuture() + stream_future.try_set_result(quantum.QuantumResult(result=result)) + client().run_job_over_stream.return_value = stream_future @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) @@ -363,25 +367,24 @@ def test_run_circuit(client): assert result.params.param_dict == {'a': 1} assert result.measurements == {'q': np.array([[0]], dtype='uint8')} client.assert_called_with(service_args={'client_info': 1}, verbose=None) - client().create_program_async.assert_called_once() - client().create_job_async.assert_called_once_with( + client().run_job_over_stream.assert_called_once_with( project_id='proj', program_id='prog', + code=mock.ANY, job_id='job-id', - processor_ids=['mysim'], run_context=util.pack_any( v2.run_context_pb2.RunContext( parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=1)] ) ), - description=None, - labels=None, - processor_id='', + program_description=None, + program_labels=None, + job_description=None, + job_labels=None, + processor_id='mysim', run_name='', device_config_name='', ) - client().get_job_async.assert_called_once_with('proj', 'prog', 'job-id', False) - client().get_job_results_async.assert_called_once_with('proj', 'prog', 'job-id') def test_no_gate_set(): @@ -397,17 +400,7 @@ def test_unsupported_program_type(): @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_failed(client): - client().create_program_async.return_value = ( - 'prog', - quantum.QuantumProgram(name='projects/proj/programs/prog'), - ) - client().create_job_async.return_value = ( - 'job-id', - quantum.QuantumJob( - name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} - ), - ) - client().get_job_async.return_value = quantum.QuantumJob( + failed_job = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={ 'state': 'FAILURE', @@ -415,6 +408,9 @@ def test_run_circuit_failed(client): 'failure': {'error_code': 'SYSTEM_ERROR', 'error_message': 'Not good'}, }, ) + stream_future = duet.AwaitableFuture() + stream_future.try_set_result(failed_job) + client().run_job_over_stream.return_value = stream_future engine = cg.Engine(project_id='proj') with pytest.raises( @@ -427,23 +423,16 @@ def test_run_circuit_failed(client): @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_failed_missing_processor_name(client): - client().create_program_async.return_value = ( - 'prog', - quantum.QuantumProgram(name='projects/proj/programs/prog'), - ) - client().create_job_async.return_value = ( - 'job-id', - quantum.QuantumJob( - name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} - ), - ) - client().get_job_async.return_value = quantum.QuantumJob( + failed_job = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={ 'state': 'FAILURE', 'failure': {'error_code': 'SYSTEM_ERROR', 'error_message': 'Not good'}, }, ) + stream_future = duet.AwaitableFuture() + stream_future.try_set_result(failed_job) + client().run_job_over_stream.return_value = stream_future engine = cg.Engine(project_id='proj') with pytest.raises( @@ -456,19 +445,12 @@ def test_run_circuit_failed_missing_processor_name(client): @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_circuit_cancelled(client): - client().create_program_async.return_value = ( - 'prog', - quantum.QuantumProgram(name='projects/proj/programs/prog'), - ) - client().create_job_async.return_value = ( - 'job-id', - quantum.QuantumJob( - name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} - ), - ) - client().get_job_async.return_value = quantum.QuantumJob( + canceled_job = quantum.QuantumJob( name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'CANCELLED'} ) + stream_future = duet.AwaitableFuture() + stream_future.try_set_result(canceled_job) + client().run_job_over_stream.return_value = stream_future engine = cg.Engine(project_id='proj') with pytest.raises( @@ -477,27 +459,6 @@ def test_run_circuit_cancelled(client): engine.run(program=_CIRCUIT) -@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) -def test_run_circuit_timeout(client): - client().create_program_async.return_value = ( - 'prog', - quantum.QuantumProgram(name='projects/proj/programs/prog'), - ) - client().create_job_async.return_value = ( - 'job-id', - quantum.QuantumJob( - name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'} - ), - ) - client().get_job_async.return_value = quantum.QuantumJob( - name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'RUNNING'} - ) - - engine = cg.Engine(project_id='project-id', timeout=1) - with pytest.raises(TimeoutError): - engine.run(program=_CIRCUIT) - - @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) def test_run_sweep_params(client): setup_run_circuit_with_result_(client, _RESULTS) @@ -513,18 +474,25 @@ def test_run_sweep_params(client): assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - client().create_program_async.assert_called_once() - client().create_job_async.assert_called_once() + client().run_job_over_stream.assert_called_once() run_context = v2.run_context_pb2.RunContext() - client().create_job_async.call_args[1]['run_context'].Unpack(run_context) + client().run_job_over_stream.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 2 for i, v in enumerate([1.0, 2.0]): assert sweeps[i].repetitions == 1 assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v] - client().get_job_async.assert_called_once() - client().get_job_results_async.assert_called_once() + + +def test_run_sweep_with_multiple_processor_ids(): + engine = cg.Engine(project_id='proj', proto_version=cg.engine.engine.ProtoVersion.V2) + with pytest.raises(ValueError, match='multiple processors is no longer supported'): + _ = engine.run_sweep( + program=_CIRCUIT, + params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})], + processor_ids=['mysim', 'mysim2'], + ) @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) @@ -570,16 +538,13 @@ def test_run_sweep_v2(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - client().create_program_async.assert_called_once() - client().create_job_async.assert_called_once() + client().run_job_over_stream.assert_called_once() run_context = v2.run_context_pb2.RunContext() - client().create_job_async.call_args[1]['run_context'].Unpack(run_context) + client().run_job_over_stream.call_args[1]['run_context'].Unpack(run_context) sweeps = run_context.parameter_sweeps assert len(sweeps) == 1 assert sweeps[0].repetitions == 1 assert sweeps[0].sweep.single_sweep.points.points == [1, 2] - client().get_job_async.assert_called_once() - client().get_job_results_async.assert_called_once() @mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True) @@ -735,7 +700,7 @@ def test_bad_program_proto(): engine = cg.Engine( project_id='project-id', proto_version=cg.engine.engine.ProtoVersion.UNDEFINED ) - with pytest.raises(ValueError, match='invalid program proto version'): + with pytest.raises(ValueError, match='invalid (program|run context) proto version'): engine.run_sweep(program=_CIRCUIT) with pytest.raises(ValueError, match='invalid program proto version'): engine.create_program(_CIRCUIT) @@ -820,7 +785,7 @@ def test_sampler(client): assert results[i].repetitions == 1 assert results[i].params.param_dict == {'a': v} assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')} - assert client().create_program_async.call_args[0][0] == 'proj' + assert client().run_job_over_stream.call_args[1]['project_id'] == 'proj' with cirq.testing.assert_deprecated('sampler', deadline='1.0'): _ = engine.sampler(processor_id='tmp')