Skip to content

Commit

Permalink
Create separate unary and streaming RPC tests in engine_test.py (#6311)
Browse files Browse the repository at this point in the history
  • Loading branch information
verult committed Oct 5, 2023
1 parent 4dc36d5 commit 87f77be
Showing 1 changed file with 228 additions and 15 deletions.
243 changes: 228 additions & 15 deletions cirq-google/cirq_google/engine/engine_test.py
Expand Up @@ -355,10 +355,50 @@ def setup_run_circuit_with_result_(client, result):


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_circuit(client):
def test_run_circuit_with_unary_rpcs(client):
setup_run_circuit_with_result_(client, _A_RESULT)

engine = cg.Engine(project_id='proj', service_args={'client_info': 1})
engine = cg.Engine(
project_id='proj',
context=EngineContext(service_args={'client_info': 1}, enable_streaming=False),
)
result = engine.run(
program=_CIRCUIT, program_id='prog', job_id='job-id', processor_ids=['mysim']
)

assert result.repetitions == 1
assert result.params.param_dict == {'a': 1}
assert result.measurements == {'q': np.array([[0]], dtype='uint8')}
client.assert_called_with(service_args={'client_info': 1}, verbose=None)
client().create_program_async.assert_called_once()
client().create_job_async.assert_called_once_with(
project_id='proj',
program_id='prog',
job_id='job-id',
processor_ids=['mysim'],
run_context=util.pack_any(
v2.run_context_pb2.RunContext(
parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=1)]
)
),
description=None,
labels=None,
processor_id='',
run_name='',
device_config_name='',
)
client().get_job_async.assert_called_once_with('proj', 'prog', 'job-id', False)
client().get_job_results_async.assert_called_once_with('proj', 'prog', 'job-id')


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_circuit_with_stream_rpcs(client):
setup_run_circuit_with_result_(client, _A_RESULT)

engine = cg.Engine(
project_id='proj',
context=EngineContext(service_args={'client_info': 1}, enable_streaming=True),
)
result = engine.run(
program=_CIRCUIT, program_id='prog', job_id='job-id', processor_ids=['mysim']
)
Expand Down Expand Up @@ -399,7 +439,37 @@ def test_unsupported_program_type():


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_circuit_failed(client):
def test_run_circuit_failed_with_unary_rpcs(client):
client().create_program_async.return_value = (
'prog',
quantum.QuantumProgram(name='projects/proj/programs/prog'),
)
client().create_job_async.return_value = (
'job-id',
quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'}
),
)
client().get_job_async.return_value = quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job-id',
execution_status={
'state': 'FAILURE',
'processor_name': 'myqc',
'failure': {'error_code': 'SYSTEM_ERROR', 'error_message': 'Not good'},
},
)

engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
with pytest.raises(
RuntimeError,
match='Job projects/proj/programs/prog/jobs/job-id on processor'
' myqc failed. SYSTEM_ERROR: Not good',
):
engine.run(program=_CIRCUIT)


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_circuit_failed_with_stream_rpcs(client):
failed_job = quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job-id',
execution_status={
Expand All @@ -412,7 +482,7 @@ def test_run_circuit_failed(client):
stream_future.try_set_result(failed_job)
client().run_job_over_stream.return_value = stream_future

engine = cg.Engine(project_id='proj')
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
with pytest.raises(
RuntimeError,
match='Job projects/proj/programs/prog/jobs/job-id on processor'
Expand All @@ -422,7 +492,36 @@ def test_run_circuit_failed(client):


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_circuit_failed_missing_processor_name(client):
def test_run_circuit_failed_missing_processor_name_with_unary_rpcs(client):
client().create_program_async.return_value = (
'prog',
quantum.QuantumProgram(name='projects/proj/programs/prog'),
)
client().create_job_async.return_value = (
'job-id',
quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'}
),
)
client().get_job_async.return_value = quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job-id',
execution_status={
'state': 'FAILURE',
'failure': {'error_code': 'SYSTEM_ERROR', 'error_message': 'Not good'},
},
)

engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
with pytest.raises(
RuntimeError,
match='Job projects/proj/programs/prog/jobs/job-id on processor'
' UNKNOWN failed. SYSTEM_ERROR: Not good',
):
engine.run(program=_CIRCUIT)


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_circuit_failed_missing_processor_name_with_stream_rpcs(client):
failed_job = quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job-id',
execution_status={
Expand All @@ -434,7 +533,7 @@ def test_run_circuit_failed_missing_processor_name(client):
stream_future.try_set_result(failed_job)
client().run_job_over_stream.return_value = stream_future

engine = cg.Engine(project_id='proj')
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
with pytest.raises(
RuntimeError,
match='Job projects/proj/programs/prog/jobs/job-id on processor'
Expand All @@ -444,26 +543,78 @@ def test_run_circuit_failed_missing_processor_name(client):


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_circuit_cancelled(client):
def test_run_circuit_cancelled_with_unary_rpcs(client):
client().create_program_async.return_value = (
'prog',
quantum.QuantumProgram(name='projects/proj/programs/prog'),
)
client().create_job_async.return_value = (
'job-id',
quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'READY'}
),
)
client().get_job_async.return_value = quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'CANCELLED'}
)

engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
with pytest.raises(
RuntimeError, match='Job projects/proj/programs/prog/jobs/job-id failed in state CANCELLED.'
):
engine.run(program=_CIRCUIT)


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_circuit_cancelled_with_stream_rpcs(client):
canceled_job = quantum.QuantumJob(
name='projects/proj/programs/prog/jobs/job-id', execution_status={'state': 'CANCELLED'}
)
stream_future = duet.AwaitableFuture()
stream_future.try_set_result(canceled_job)
client().run_job_over_stream.return_value = stream_future

engine = cg.Engine(project_id='proj')
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
with pytest.raises(
RuntimeError, match='Job projects/proj/programs/prog/jobs/job-id failed in state CANCELLED.'
):
engine.run(program=_CIRCUIT)


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_sweep_params(client):
def test_run_sweep_params_with_unary_rpcs(client):
setup_run_circuit_with_result_(client, _RESULTS)

engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
job = engine.run_sweep(
program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})]
)
results = job.results()
assert len(results) == 2
for i, v in enumerate([1, 2]):
assert results[i].repetitions == 1
assert results[i].params.param_dict == {'a': v}
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}

client().create_program_async.assert_called_once()
client().create_job_async.assert_called_once()

run_context = v2.run_context_pb2.RunContext()
client().create_job_async.call_args[1]['run_context'].Unpack(run_context)
sweeps = run_context.parameter_sweeps
assert len(sweeps) == 2
for i, v in enumerate([1.0, 2.0]):
assert sweeps[i].repetitions == 1
assert sweeps[i].sweep.sweep_function.sweeps[0].single_sweep.points.points == [v]
client().get_job_async.assert_called_once()
client().get_job_results_async.assert_called_once()


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_sweep_params_with_stream_rpcs(client):
setup_run_circuit_with_result_(client, _RESULTS)

engine = cg.Engine(project_id='proj')
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
job = engine.run_sweep(
program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})]
)
Expand All @@ -486,7 +637,12 @@ def test_run_sweep_params(client):


def test_run_sweep_with_multiple_processor_ids():
engine = cg.Engine(project_id='proj', proto_version=cg.engine.engine.ProtoVersion.V2)
engine = cg.Engine(
project_id='proj',
context=EngineContext(
proto_version=cg.engine.engine.ProtoVersion.V2, enable_streaming=True
),
)
with pytest.raises(ValueError, match='multiple processors is no longer supported'):
_ = engine.run_sweep(
program=_CIRCUIT,
Expand Down Expand Up @@ -527,10 +683,44 @@ def test_run_multiple_times(client):


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_sweep_v2(client):
def test_run_sweep_v2_with_unary_rpcs(client):
setup_run_circuit_with_result_(client, _RESULTS_V2)

engine = cg.Engine(project_id='proj', proto_version=cg.engine.engine.ProtoVersion.V2)
engine = cg.Engine(
project_id='proj',
context=EngineContext(
proto_version=cg.engine.engine.ProtoVersion.V2, enable_streaming=False
),
)
job = engine.run_sweep(program=_CIRCUIT, job_id='job-id', params=cirq.Points('a', [1, 2]))
results = job.results()
assert len(results) == 2
for i, v in enumerate([1, 2]):
assert results[i].repetitions == 1
assert results[i].params.param_dict == {'a': v}
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}
client().create_program_async.assert_called_once()
client().create_job_async.assert_called_once()
run_context = v2.run_context_pb2.RunContext()
client().create_job_async.call_args[1]['run_context'].Unpack(run_context)
sweeps = run_context.parameter_sweeps
assert len(sweeps) == 1
assert sweeps[0].repetitions == 1
assert sweeps[0].sweep.single_sweep.points.points == [1, 2]
client().get_job_async.assert_called_once()
client().get_job_results_async.assert_called_once()


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_sweep_v2_with_stream_rpcs(client):
setup_run_circuit_with_result_(client, _RESULTS_V2)

engine = cg.Engine(
project_id='proj',
context=EngineContext(
proto_version=cg.engine.engine.ProtoVersion.V2, enable_streaming=True
),
)
job = engine.run_sweep(program=_CIRCUIT, job_id='job-id', params=cirq.Points('a', [1, 2]))
results = job.results()
assert len(results) == 2
Expand Down Expand Up @@ -772,10 +962,33 @@ def test_get_processor():


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_sampler(client):
def test_sampler_with_unary_rpcs(client):
setup_run_circuit_with_result_(client, _RESULTS)

engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
sampler = engine.get_sampler(processor_id='tmp')
results = sampler.run_sweep(
program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})]
)
assert len(results) == 2
for i, v in enumerate([1, 2]):
assert results[i].repetitions == 1
assert results[i].params.param_dict == {'a': v}
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}
assert client().create_program_async.call_args[0][0] == 'proj'

with cirq.testing.assert_deprecated('sampler', deadline='1.0'):
_ = engine.sampler(processor_id='tmp')

with pytest.raises(ValueError, match='list of processors'):
_ = engine.get_sampler(['test1', 'test2'])


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_sampler_with_stream_rpcs(client):
setup_run_circuit_with_result_(client, _RESULTS)

engine = cg.Engine(project_id='proj')
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=True))
sampler = engine.get_sampler(processor_id='tmp')
results = sampler.run_sweep(
program=_CIRCUIT, params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})]
Expand Down

0 comments on commit 87f77be

Please sign in to comment.