Skip to content

Commit

Permalink
Changed simulators fallback to decompose_once and removed ancilla s…
Browse files Browse the repository at this point in the history
…upport from `DensityMatrixSimulator` (#6127)

* Fix bugs in strat_act_on_from_apply_decompose and improve support for qubit allocation within decompose

* Revert unrelated mypy change

* Fix mypy types and remove context argument from strat_act_on_from_apply_decompose

* Fix mypy error

* Update docstrings
  • Loading branch information
tanujkhattar committed Jun 8, 2023
1 parent ebc52d5 commit 0ef302f
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 156 deletions.
16 changes: 0 additions & 16 deletions cirq-core/cirq/sim/density_matrix_simulation_state.py
Expand Up @@ -285,22 +285,6 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def remove_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().remove_qubits(qubits)
if ret is not NotImplemented:
return ret
extracted, remainder = self.factor(qubits)
remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0]
return remainder

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down
12 changes: 0 additions & 12 deletions cirq-core/cirq/sim/density_matrix_simulation_state_test.py
Expand Up @@ -123,15 +123,3 @@ def test_initial_state_bad_shape():
cirq.DensityMatrixSimulationState(
qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64
)


def test_remove_qubits():
"""Test the remove_qubits method."""
q1 = cirq.LineQubit(0)
q2 = cirq.LineQubit(1)
state = cirq.DensityMatrixSimulationState(qubits=[q1, q2])

new_state = state.remove_qubits([q1])

assert len(new_state.qubits) == 1
assert q1 not in new_state.qubits
67 changes: 33 additions & 34 deletions cirq-core/cirq/sim/simulation_state.py
Expand Up @@ -23,6 +23,7 @@
List,
Optional,
Sequence,
Set,
TypeVar,
TYPE_CHECKING,
Tuple,
Expand All @@ -31,8 +32,8 @@

import numpy as np

from cirq import protocols, value
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq import ops, protocols, value

from cirq.sim.simulation_state_base import SimulationStateBase

TState = TypeVar('TState', bound='cirq.QuantumStateRepresentation')
Expand Down Expand Up @@ -166,35 +167,35 @@ def create_merged_state(self) -> Self:
"""Creates a final merged state."""
return self

def add_qubits(self: Self, qubits: Sequence['cirq.Qid']):
"""Add qubits to a new state space and take the kron product.
Note that only Density Matrix and State Vector simulators
override this function.
def add_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
"""Add `qubits` in the `|0>` state to a new state space and take the kron product.
Args:
qubits: Sequence of qubits to be added.
Returns:
NotImplemented: If the subclass does not implement this method.
Raises:
ValueError: If a qubit being added is already tracked.
Self: A `cirq.SimulationState` with qubits added or `self` if there are no qubits to
add.
"""
if any(q in self.qubits for q in qubits):
raise ValueError(f"Qubit to add {qubits} should not already be tracked.")
if not qubits:
return self
return NotImplemented

def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
"""Remove qubits from the state space.
"""Remove `qubits` from the state space.
The qubits to be removed should be untangled from rest of the system and in the |0> state.
Args:
qubits: Sequence of qubits to be added.
qubits: Sequence of qubits to be removed.
Returns:
A new Simulation State with qubits removed. Or
`self` if there are no qubits to remove."""
if qubits is None or not qubits:
NotImplemented: If the subclass does not implement this method.
Self: A `cirq.SimulationState` with qubits removed or `self` if there are no qubits to
remove.
"""
if not qubits:
return self
return NotImplemented

Expand Down Expand Up @@ -325,25 +326,23 @@ def can_represent_mixed_states(self) -> bool:
def strat_act_on_from_apply_decompose(
val: Any, args: 'cirq.SimulationState', qubits: Sequence['cirq.Qid']
) -> bool:
operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
if operations is None:
if isinstance(val, ops.Gate):
decomposed = protocols.decompose_once_with_qubits(val, qubits, flatten=False, default=None)
else:
decomposed = protocols.decompose_once(val, flatten=False, default=None)
if decomposed is None:
return NotImplemented
assert len(qubits1) == len(qubits)
all_qubits = frozenset([q for op in operations for q in op.qubits])
qubit_map = dict(zip(all_qubits, all_qubits))
qubit_map.update(dict(zip(qubits1, qubits)))
new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits)
args = args.add_qubits(new_ancilla)
if args is NotImplemented:
return NotImplemented
for operation in operations:
operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
all_ancilla: Set['cirq.Qid'] = set()
for operation in ops.flatten_to_ops(decomposed):
curr_ancilla = tuple(q for q in operation.qubits if q not in args.qubits)
args = args.add_qubits(curr_ancilla)
if args is NotImplemented:
return NotImplemented
all_ancilla.update(curr_ancilla)
protocols.act_on(operation, args)
args = args.remove_qubits(new_ancilla)
if args is NotImplemented: # coverage: ignore
raise TypeError( # coverage: ignore
f"{type(args)} implements `add_qubits` but not `remove_qubits`." # coverage: ignore
) # coverage: ignore
args = args.remove_qubits(tuple(all_ancilla))
if args is NotImplemented:
raise TypeError(f"{type(args)} implements add_qubits but not remove_qubits.")
return True


Expand Down
146 changes: 52 additions & 94 deletions cirq-core/cirq/sim/simulation_state_test.py
Expand Up @@ -42,61 +42,26 @@ def _act_on_fallback_(
) -> bool:
return True


class AncillaZ(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.CX(qubits[0], ancilla)
yield cirq.Z(ancilla) ** self._exponent
yield cirq.CX(qubits[0], ancilla)


class AncillaH(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.H(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.H(ancilla) ** self._exponent


class AncillaY(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.Y(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.Y(ancilla) ** self._exponent
def add_qubits(self, qubits):
ret = super().add_qubits(qubits)
return self if NotImplemented else ret


class DelegatingAncillaZ(cirq.Gate):
def __init__(self, exponent=1):
def __init__(self, exponent=1, measure_ancilla: bool = False):
self._exponent = exponent
self._measure_ancilla = measure_ancilla

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
a = cirq.NamedQubit('a')
yield cirq.CX(qubits[0], a)
yield AncillaZ(self._exponent).on(a)
yield PhaseUsingCleanAncilla(self._exponent).on(a)
yield cirq.CX(qubits[0], a)
if self._measure_ancilla:
yield cirq.measure(a)


class Composite(cirq.Gate):
Expand All @@ -115,12 +80,23 @@ def test_measurements():

def test_decompose():
args = DummySimulationState()
assert (
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
is NotImplemented
assert simulation_state.strat_act_on_from_apply_decompose(
Composite(), args, [cirq.LineQubit(0)]
)


def test_decompose_for_gate_allocating_qubits_raises():
class Composite(cirq.testing.SingleQubitGate):
def _decompose_(self, qubits):
anc = cirq.NamedQubit("anc")
yield cirq.CNOT(*qubits, anc)

args = DummySimulationState()

with pytest.raises(TypeError, match="add_qubits but not remove_qubits"):
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])


def test_mapping():
args = DummySimulationState()
assert list(iter(args)) == cirq.LineQubit.range(2)
Expand Down Expand Up @@ -162,53 +138,35 @@ def test_field_getters():
assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_z(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_y(exp):
@pytest.mark.parametrize('exp', np.linspace(0, 2 * np.pi, 10))
def test_delegating_gate_unitary(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaY(exp).on(q))

control_circuit = cirq.Circuit(cirq.Y(q))
control_circuit.append(cirq.Y(q))
control_circuit.append(cirq.XPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_borrowable_qubit(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(cirq.X(q))
test_circuit.append(AncillaH(exp).on(q))
test_circuit.append(DelegatingAncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.H(q))
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_delegating_gate_qubit(exp):
@pytest.mark.parametrize('exp', np.linspace(0, 2 * np.pi, 10))
def test_delegating_gate_channel(exp):
q = cirq.LineQubit(0)

test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(DelegatingAncillaZ(exp).on(q))
test_circuit.append(DelegatingAncillaZ(exp, True).on(q))

control_circuit = cirq.Circuit(cirq.H(q))
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
with pytest.raises(TypeError, match="DensityMatrixSimulator doesn't support"):
# TODO: This test should pass once we extend support to DensityMatrixSimulator.
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
Expand All @@ -221,7 +179,8 @@ def test_phase_using_dirty_ancilla(num_ancilla: int):
u.on(q, *anc), PhaseUsingDirtyAncilla(ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q, *anc), cirq.Z(q))
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
Expand All @@ -233,25 +192,24 @@ def test_phase_using_clean_ancilla(num_ancilla: int, theta: float):
u.on(q), PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q), cirq.ZPowGate(exponent=theta).on(q))
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


def test_add_qubits_raise_value_error(num_ancilla=1):
q = cirq.LineQubit(0)
args = cirq.StateVectorSimulationState(qubits=[q])

with pytest.raises(ValueError, match='should not already be tracked.'):
args.add_qubits([q])
assert_test_circuit_for_dm_simulator(test_circuit, control_circuit)
assert_test_circuit_for_sv_simulator(test_circuit, control_circuit)


def test_remove_qubits_not_implemented(num_ancilla=1):
args = DummySimulationState()

assert args.remove_qubits([cirq.LineQubit(0)]) is NotImplemented
def assert_test_circuit_for_dm_simulator(test_circuit, control_circuit) -> None:
# Density Matrix Simulator: For unitary gates, this fallbacks to `cirq.apply_channel`
# which recursively calls to `cirq.apply_unitary(decompose=True)`.
for split_untangled_states in [True, False]:
sim = cirq.DensityMatrixSimulator(split_untangled_states=split_untangled_states)
control_sim = sim.simulate(control_circuit).final_density_matrix
test_sim = sim.simulate(test_circuit).final_density_matrix
assert np.allclose(test_sim, control_sim)


def assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) -> None:
for test_simulator in ['cirq.final_state_vector', 'cirq.final_density_matrix']:
test_sim = eval(test_simulator)(test_circuit)
control_sim = eval(test_simulator)(control_circuit)
def assert_test_circuit_for_sv_simulator(test_circuit, control_circuit) -> None:
# State Vector Simulator.
for split_untangled_states in [True, False]:
sim = cirq.Simulator(split_untangled_states=split_untangled_states)
control_sim = sim.simulate(control_circuit).final_state_vector
test_sim = sim.simulate(test_circuit).final_state_vector
assert np.allclose(test_sim, control_sim)

0 comments on commit 0ef302f

Please sign in to comment.