Skip to content

Commit

Permalink
Update cirq_google to import from cirq directly (#4156)
Browse files Browse the repository at this point in the history
Now that cirq-google is a separate package from cirq-core, we can import cirq like a separate library. We use the top-level `cirq` namespace wherever possible instead of importing subpackages, and also remove some of the `if TYPE_CHECKING` guards and unquote cirq type annotations.
  • Loading branch information
maffoo committed Jun 2, 2021
1 parent bc4123e commit 0311e17
Show file tree
Hide file tree
Showing 42 changed files with 869 additions and 959 deletions.
47 changes: 24 additions & 23 deletions cirq-google/cirq_google/api/v1/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,82 +13,83 @@
# limitations under the License.
from typing import cast

from cirq_google.api.v1 import params_pb2
import cirq
from cirq.study import sweeps
from cirq_google.api.v1 import params_pb2


def sweep_to_proto(sweep: sweeps.Sweep, repetitions: int = 1) -> params_pb2.ParameterSweep:
def sweep_to_proto(sweep: cirq.Sweep, repetitions: int = 1) -> params_pb2.ParameterSweep:
"""Converts sweep into an equivalent protobuf representation."""
product_sweep = None
if not sweep == sweeps.UnitSweep:
if sweep != cirq.UnitSweep:
sweep = _to_zip_product(sweep)
product_sweep = params_pb2.ProductSweep(
factors=[_sweep_zip_to_proto(cast(sweeps.Zip, factor)) for factor in sweep.factors]
factors=[_sweep_zip_to_proto(cast(cirq.Zip, factor)) for factor in sweep.factors]
)
msg = params_pb2.ParameterSweep(repetitions=repetitions, sweep=product_sweep)
return msg


def _to_zip_product(sweep: sweeps.Sweep) -> sweeps.Product:
def _to_zip_product(sweep: cirq.Sweep) -> cirq.Product:
"""Converts sweep to a product of zips of single sweeps, if possible."""
if not isinstance(sweep, sweeps.Product):
sweep = sweeps.Product(sweep)
if not all(isinstance(f, sweeps.Zip) for f in sweep.factors):
factors = [f if isinstance(f, sweeps.Zip) else sweeps.Zip(f) for f in sweep.factors]
sweep = sweeps.Product(*factors)
if not isinstance(sweep, cirq.Product):
sweep = cirq.Product(sweep)
if not all(isinstance(f, cirq.Zip) for f in sweep.factors):
factors = [f if isinstance(f, cirq.Zip) else cirq.Zip(f) for f in sweep.factors]
sweep = cirq.Product(*factors)
for factor in sweep.factors:
for term in cast(sweeps.Zip, factor).sweeps:
for term in cast(cirq.Zip, factor).sweeps:
if not isinstance(term, sweeps.SingleSweep):
raise ValueError(f'cannot convert to zip-product form: {sweep}')
return sweep


def _sweep_zip_to_proto(sweep: sweeps.Zip) -> params_pb2.ZipSweep:
def _sweep_zip_to_proto(sweep: cirq.Zip) -> params_pb2.ZipSweep:
sweep_list = [_single_param_sweep_to_proto(cast(sweeps.SingleSweep, s)) for s in sweep.sweeps]
return params_pb2.ZipSweep(sweeps=sweep_list)


def _single_param_sweep_to_proto(sweep: sweeps.SingleSweep) -> params_pb2.SingleSweep:
if isinstance(sweep, sweeps.Linspace):
if isinstance(sweep, cirq.Linspace):
return params_pb2.SingleSweep(
parameter_key=sweep.key,
linspace=params_pb2.Linspace(
first_point=sweep.start, last_point=sweep.stop, num_points=sweep.length
),
)
elif isinstance(sweep, sweeps.Points):
elif isinstance(sweep, cirq.Points):
return params_pb2.SingleSweep(
parameter_key=sweep.key, points=params_pb2.Points(points=sweep.points)
)
else:
raise ValueError(f'invalid single-parameter sweep: {sweep}')


def sweep_from_proto(param_sweep: params_pb2.ParameterSweep) -> sweeps.Sweep:
def sweep_from_proto(param_sweep: params_pb2.ParameterSweep) -> cirq.Sweep:
if param_sweep.HasField('sweep') and len(param_sweep.sweep.factors) > 0:
return sweeps.Product(
return cirq.Product(
*[_sweep_from_param_sweep_zip_proto(f) for f in param_sweep.sweep.factors]
)
return sweeps.UnitSweep
return cirq.UnitSweep


def _sweep_from_param_sweep_zip_proto(param_sweep_zip: params_pb2.ZipSweep) -> sweeps.Sweep:
def _sweep_from_param_sweep_zip_proto(param_sweep_zip: params_pb2.ZipSweep) -> cirq.Sweep:
if len(param_sweep_zip.sweeps) > 0:
return sweeps.Zip(
return cirq.Zip(
*[_sweep_from_single_param_sweep_proto(sweep) for sweep in param_sweep_zip.sweeps]
)
return sweeps.UnitSweep
return cirq.UnitSweep


def _sweep_from_single_param_sweep_proto(
single_param_sweep: params_pb2.SingleSweep,
) -> sweeps.Sweep:
) -> cirq.Sweep:
key = single_param_sweep.parameter_key
if single_param_sweep.HasField('points'):
points = single_param_sweep.points
return sweeps.Points(key, list(points.points))
return cirq.Points(key, list(points.points))
if single_param_sweep.HasField('linspace'):
sl = single_param_sweep.linspace
return sweeps.Linspace(key, sl.first_point, sl.last_point, sl.num_points)
return cirq.Linspace(key, sl.first_point, sl.last_point, sl.num_points)

raise ValueError('Single param sweep type undefined')
73 changes: 36 additions & 37 deletions cirq-google/cirq_google/api/v1/programs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
import numpy as np
import sympy

from cirq import devices, ops, protocols, value, circuits
import cirq
from cirq_google.api.v1 import operations_pb2

if TYPE_CHECKING:
import cirq_google
import cirq


def _load_json_bool(b: Any):
Expand All @@ -32,46 +31,46 @@ def _load_json_bool(b: Any):


def gate_to_proto(
gate: 'cirq.Gate', qubits: Tuple['cirq.Qid', ...], delay: int
gate: cirq.Gate, qubits: Tuple[cirq.Qid, ...], delay: int
) -> operations_pb2.Operation:
if isinstance(gate, ops.MeasurementGate):
if isinstance(gate, cirq.MeasurementGate):
return operations_pb2.Operation(
incremental_delay_picoseconds=delay, measurement=_measure_to_proto(gate, qubits)
)

if isinstance(gate, ops.XPowGate):
if isinstance(gate, cirq.XPowGate):
if len(qubits) != 1:
# coverage: ignore
raise ValueError('Wrong number of qubits.')
return operations_pb2.Operation(
incremental_delay_picoseconds=delay, exp_w=_x_to_proto(gate, qubits[0])
)

if isinstance(gate, ops.YPowGate):
if isinstance(gate, cirq.YPowGate):
if len(qubits) != 1:
# coverage: ignore
raise ValueError('Wrong number of qubits.')
return operations_pb2.Operation(
incremental_delay_picoseconds=delay, exp_w=_y_to_proto(gate, qubits[0])
)

if isinstance(gate, ops.PhasedXPowGate):
if isinstance(gate, cirq.PhasedXPowGate):
if len(qubits) != 1:
# coverage: ignore
raise ValueError('Wrong number of qubits.')
return operations_pb2.Operation(
incremental_delay_picoseconds=delay, exp_w=_phased_x_to_proto(gate, qubits[0])
)

if isinstance(gate, ops.ZPowGate):
if isinstance(gate, cirq.ZPowGate):
if len(qubits) != 1:
# coverage: ignore
raise ValueError('Wrong number of qubits.')
return operations_pb2.Operation(
incremental_delay_picoseconds=delay, exp_z=_z_to_proto(gate, qubits[0])
)

if isinstance(gate, ops.CZPowGate):
if isinstance(gate, cirq.CZPowGate):
if len(qubits) != 2:
# coverage: ignore
raise ValueError('Wrong number of qubits.')
Expand All @@ -82,37 +81,37 @@ def gate_to_proto(
raise ValueError(f"Don't know how to serialize this gate: {gate!r}")


def _x_to_proto(gate: 'cirq.XPowGate', q: 'cirq.Qid') -> operations_pb2.ExpW:
def _x_to_proto(gate: cirq.XPowGate, q: cirq.Qid) -> operations_pb2.ExpW:
return operations_pb2.ExpW(
target=_qubit_to_proto(q),
axis_half_turns=_parameterized_value_to_proto(0),
half_turns=_parameterized_value_to_proto(gate.exponent),
)


def _y_to_proto(gate: 'cirq.YPowGate', q: 'cirq.Qid') -> operations_pb2.ExpW:
def _y_to_proto(gate: cirq.YPowGate, q: cirq.Qid) -> operations_pb2.ExpW:
return operations_pb2.ExpW(
target=_qubit_to_proto(q),
axis_half_turns=_parameterized_value_to_proto(0.5),
half_turns=_parameterized_value_to_proto(gate.exponent),
)


def _phased_x_to_proto(gate: 'cirq.PhasedXPowGate', q: 'cirq.Qid') -> operations_pb2.ExpW:
def _phased_x_to_proto(gate: cirq.PhasedXPowGate, q: cirq.Qid) -> operations_pb2.ExpW:
return operations_pb2.ExpW(
target=_qubit_to_proto(q),
axis_half_turns=_parameterized_value_to_proto(gate.phase_exponent),
half_turns=_parameterized_value_to_proto(gate.exponent),
)


def _z_to_proto(gate: 'cirq.ZPowGate', q: 'cirq.Qid') -> operations_pb2.ExpZ:
def _z_to_proto(gate: cirq.ZPowGate, q: cirq.Qid) -> operations_pb2.ExpZ:
return operations_pb2.ExpZ(
target=_qubit_to_proto(q), half_turns=_parameterized_value_to_proto(gate.exponent)
)


def _cz_to_proto(gate: 'cirq.CZPowGate', p: 'cirq.Qid', q: 'cirq.Qid') -> operations_pb2.Exp11:
def _cz_to_proto(gate: cirq.CZPowGate, p: cirq.Qid, q: cirq.Qid) -> operations_pb2.Exp11:
return operations_pb2.Exp11(
target1=_qubit_to_proto(p),
target2=_qubit_to_proto(q),
Expand All @@ -124,7 +123,7 @@ def _qubit_to_proto(qubit):
return operations_pb2.Qubit(row=qubit.row, col=qubit.col)


def _measure_to_proto(gate: 'cirq.MeasurementGate', qubits: Sequence['cirq.Qid']):
def _measure_to_proto(gate: cirq.MeasurementGate, qubits: Sequence[cirq.Qid]):
if len(qubits) == 0:
raise ValueError('Measurement gate on no qubits.')

Expand All @@ -139,12 +138,12 @@ def _measure_to_proto(gate: 'cirq.MeasurementGate', qubits: Sequence['cirq.Qid']
)
return operations_pb2.Measurement(
targets=[_qubit_to_proto(q) for q in qubits],
key=protocols.measurement_key(gate),
key=cirq.measurement_key(gate),
invert_mask=invert_mask,
)


def circuit_as_schedule_to_protos(circuit: 'cirq.Circuit') -> Iterator[operations_pb2.Operation]:
def circuit_as_schedule_to_protos(circuit: cirq.Circuit) -> Iterator[operations_pb2.Operation]:
"""Convert a circuit into an iterable of protos.
Args:
Expand All @@ -161,7 +160,7 @@ def circuit_as_schedule_to_protos(circuit: 'cirq.Circuit') -> Iterator[operation
delay = time_picos
else:
delay = time_picos - last_picos
op_proto = gate_to_proto(cast(ops.Gate, op.gate), op.qubits, delay)
op_proto = gate_to_proto(cast(cirq.Gate, op.gate), op.qubits, delay)
time_picos += 1
last_picos = time_picos
yield op_proto
Expand All @@ -170,13 +169,13 @@ def circuit_as_schedule_to_protos(circuit: 'cirq.Circuit') -> Iterator[operation
def circuit_from_schedule_from_protos(
device: 'cirq_google.XmonDevice',
ops: Iterable[operations_pb2.Operation],
) -> 'cirq.Circuit':
) -> cirq.Circuit:
"""Convert protos into a Circuit for the given device."""
result = []
for op in ops:
xmon_op = xmon_op_from_proto(op)
result.append(xmon_op)
return circuits.Circuit(result, device=device)
return cirq.Circuit(result, device=device)


def pack_results(measurements: Sequence[Tuple[str, np.ndarray]]) -> bytes:
Expand Down Expand Up @@ -252,7 +251,7 @@ def unpack_results(
return results


def is_native_xmon_op(op: 'cirq.Operation') -> bool:
def is_native_xmon_op(op: cirq.Operation) -> bool:
"""Check if the gate corresponding to an operation is a native xmon gate.
Args:
Expand All @@ -261,10 +260,10 @@ def is_native_xmon_op(op: 'cirq.Operation') -> bool:
Returns:
True if the operation is native to the xmon, false otherwise.
"""
return isinstance(op, ops.GateOperation) and is_native_xmon_gate(op.gate)
return isinstance(op, cirq.GateOperation) and is_native_xmon_gate(op.gate)


def is_native_xmon_gate(gate: 'cirq.Gate') -> bool:
def is_native_xmon_gate(gate: cirq.Gate) -> bool:
"""Check if a gate is a native xmon gate.
Args:
Expand All @@ -276,17 +275,17 @@ def is_native_xmon_gate(gate: 'cirq.Gate') -> bool:
return isinstance(
gate,
(
ops.CZPowGate,
ops.MeasurementGate,
ops.PhasedXPowGate,
ops.XPowGate,
ops.YPowGate,
ops.ZPowGate,
cirq.CZPowGate,
cirq.MeasurementGate,
cirq.PhasedXPowGate,
cirq.XPowGate,
cirq.YPowGate,
cirq.ZPowGate,
),
)


def xmon_op_from_proto(proto: operations_pb2.Operation) -> 'cirq.Operation':
def xmon_op_from_proto(proto: operations_pb2.Operation) -> cirq.Operation:
"""Convert the proto to the corresponding operation.
See protos in api/google/v1 for specification of the protos.
Expand All @@ -301,30 +300,30 @@ def xmon_op_from_proto(proto: operations_pb2.Operation) -> 'cirq.Operation':
qubit = _qubit_from_proto
if proto.HasField('exp_w'):
exp_w = proto.exp_w
return ops.PhasedXPowGate(
return cirq.PhasedXPowGate(
exponent=param(exp_w.half_turns),
phase_exponent=param(exp_w.axis_half_turns),
).on(qubit(exp_w.target))
if proto.HasField('exp_z'):
exp_z = proto.exp_z
return ops.Z(qubit(exp_z.target)) ** param(exp_z.half_turns)
return cirq.Z(qubit(exp_z.target)) ** param(exp_z.half_turns)
if proto.HasField('exp_11'):
exp_11 = proto.exp_11
return ops.CZ(qubit(exp_11.target1), qubit(exp_11.target2)) ** param(exp_11.half_turns)
return cirq.CZ(qubit(exp_11.target1), qubit(exp_11.target2)) ** param(exp_11.half_turns)
if proto.HasField('measurement'):
meas = proto.measurement
return ops.MeasurementGate(
return cirq.MeasurementGate(
num_qubits=len(meas.targets), key=meas.key, invert_mask=tuple(meas.invert_mask)
).on(*[qubit(q) for q in meas.targets])

raise ValueError(f'invalid operation: {proto}')


def _qubit_from_proto(proto: operations_pb2.Qubit):
return devices.GridQubit(row=proto.row, col=proto.col)
return cirq.GridQubit(row=proto.row, col=proto.col)


def _parameterized_value_from_proto(proto: operations_pb2.ParameterizedFloat) -> value.TParamVal:
def _parameterized_value_from_proto(proto: operations_pb2.ParameterizedFloat) -> cirq.TParamVal:
if proto.HasField('parameter_key'):
return sympy.Symbol(proto.parameter_key)
if proto.HasField('raw'):
Expand All @@ -336,7 +335,7 @@ def _parameterized_value_from_proto(proto: operations_pb2.ParameterizedFloat) ->
)


def _parameterized_value_to_proto(param: value.TParamVal) -> operations_pb2.ParameterizedFloat:
def _parameterized_value_to_proto(param: cirq.TParamVal) -> operations_pb2.ParameterizedFloat:
if isinstance(param, sympy.Symbol):
return operations_pb2.ParameterizedFloat(parameter_key=str(param.free_symbols.pop()))
else:
Expand Down

0 comments on commit 0311e17

Please sign in to comment.