Skip to content

Commit

Permalink
Refactor ActOnArgs.kron/factor/transpose to reduce code duplication (q…
Browse files Browse the repository at this point in the history
…uantumlib#4463)

Step 3 for feedforward, outlined in https://tinyurl.com/cirq-feedforward.

This PR reduces duplication by moving redundant code from ActOnArgs.kron etc, such as `qubits = self.qubits + other.qubits` into the ActOnArgs base class, and then using handlers in subclasses to append subclass-specific information. It also updates these methods with an `inplace` option to allow them to be consistent with the swap/rename methods added in quantumlib#4169

@95-martin-orion has the most context here.
  • Loading branch information
daxfohl authored and rht committed May 1, 2023
1 parent 5c19629 commit b0f35a2
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 208 deletions.
16 changes: 5 additions & 11 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,17 +301,11 @@ def __str__(self) -> str:
def _value_equality_values_(self) -> Any:
return self.qubit_map, self.M, self.simulation_options, self.grouping

def copy(self) -> 'MPSState':
state = MPSState(
qubits=self.qubits,
prng=self.prng,
simulation_options=self.simulation_options,
grouping=self.grouping,
log_of_measurement_results=self.log_of_measurement_results.copy(),
)
state.M = [x.copy() for x in self.M]
state.estimated_gate_error_list = self.estimated_gate_error_list
return state
def _on_copy(self, target: 'MPSState'):
target.simulation_options = self.simulation_options
target.grouping = self.grouping
target.M = [x.copy() for x in self.M]
target.estimated_gate_error_list = self.estimated_gate_error_list

def state_vector(self) -> np.ndarray:
"""Returns the full state vector.
Expand Down
75 changes: 66 additions & 9 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,15 @@ def __init__(
axes = ()
if log_of_measurement_results is None:
log_of_measurement_results = {}
self._qubits = tuple(qubits)
self.qubit_map = {q: i for i, q in enumerate(qubits)}
self._set_qubits(qubits)
self._axes = tuple(axes)
self.prng = prng
self._log_of_measurement_results = log_of_measurement_results

def _set_qubits(self, qubits: Sequence['cirq.Qid']):
self._qubits = tuple(qubits)
self.qubit_map = {q: i for i, q in enumerate(self.qubits)}

# TODO(#3388) Add documentation for Raises.
# pylint: disable=missing-raises-doc
def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]):
Expand Down Expand Up @@ -105,31 +108,85 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Child classes that perform measurements should implement this with
the implementation."""

@abc.abstractmethod
def copy(self: TSelf) -> TSelf:
"""Creates a copy of the object."""
args = copy.copy(self)
self._on_copy(args)
args._log_of_measurement_results = self.log_of_measurement_results.copy()
return args

def _on_copy(self: TSelf, args: TSelf):
"""Subclasses should implement this with any additional state copy
functionality."""

def create_merged_state(self: TSelf) -> TSelf:
"""Creates a final merged state."""
return self

def kronecker_product(self: TSelf, other: TSelf) -> TSelf:
def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf:
"""Joins two state spaces together."""
raise NotImplementedError()
args = self if inplace else copy.copy(self)
self._on_kronecker_product(other, args)
args._set_qubits(self.qubits + other.qubits)
return args

def _on_kronecker_product(self: TSelf, other: TSelf, target: TSelf):
"""Subclasses should implement this with any additional state product
functionality, if supported."""

def factor(
self: TSelf,
qubits: Sequence['cirq.Qid'],
*,
validate=True,
atol=1e-07,
inplace=False,
) -> Tuple[TSelf, TSelf]:
"""Splits two state spaces after a measurement or reset."""
raise NotImplementedError()
extracted = copy.copy(self)
remainder = self if inplace else copy.copy(self)
self._on_factor(qubits, extracted, remainder, validate, atol)
extracted._set_qubits(qubits)
remainder._set_qubits([q for q in self.qubits if q not in qubits])
return extracted, remainder

def _on_factor(
self: TSelf,
qubits: Sequence['cirq.Qid'],
extracted: TSelf,
remainder: TSelf,
validate=True,
atol=1e-07,
):
"""Subclasses should implement this with any additional state factor
functionality, if supported."""

def transpose_to_qubit_order(
self: TSelf, qubits: Sequence['cirq.Qid'], *, inplace=False
) -> TSelf:
"""Physically reindexes the state by the new basis.
Args:
qubits: The desired qubit order.
inplace: True to perform this operation inplace.
Returns:
The state with qubit order transposed and underlying representation
updated.
Raises:
ValueError: If the provided qubits do not match the existing ones.
"""
if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits):
raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}')
args = self if inplace else copy.copy(self)
self._on_transpose_to_qubit_order(qubits, args)
args._set_qubits(qubits)
return args

def transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid']) -> TSelf:
"""Physically reindexes the state by the new basis."""
raise NotImplementedError()
def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], target: TSelf):
"""Subclasses should implement this with any additional state transpose
functionality, if supported."""

@property
def log_of_measurement_results(self) -> Dict[str, Any]:
Expand Down
35 changes: 9 additions & 26 deletions cirq-core/cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,17 @@ def _act_on_fallback_(
) -> bool:
return True

def kronecker_product(self, other: 'EmptyActOnArgs') -> 'EmptyActOnArgs':
return EmptyActOnArgs(
qubits=self.qubits + other.qubits,
logs=self.log_of_measurement_results,
)
def _on_copy(self, args):
pass

def factor(
self,
qubits: Sequence['cirq.Qid'],
*,
validate=True,
atol=1e-07,
) -> Tuple['EmptyActOnArgs', 'EmptyActOnArgs']:
extracted_args = EmptyActOnArgs(
qubits=qubits,
logs=self.log_of_measurement_results,
)
remainder_args = EmptyActOnArgs(
qubits=tuple(q for q in self.qubits if q not in qubits),
logs=self.log_of_measurement_results,
)
return extracted_args, remainder_args
def _on_kronecker_product(self, other, target):
pass

def transpose_to_qubit_order(self, qubits: Sequence['cirq.Qid']) -> 'EmptyActOnArgs':
return EmptyActOnArgs(
qubits=qubits,
logs=self.log_of_measurement_results,
)
def _on_transpose_to_qubit_order(self, qubits, target):
pass

def _on_factor(self, qubits, extracted, remainder, validate=True, atol=1e-07):
pass

def sample(self, qubits, repetitions=1, seed=None):
pass
Expand Down
13 changes: 10 additions & 3 deletions cirq-core/cirq/sim/act_on_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ class DummyArgs(cirq.ActOnArgs):
def __init__(self):
super().__init__(qubits=cirq.LineQubit.range(2))

def copy(self):
pass

def sample(self, qubits, repetitions=1, seed=None):
pass

Expand Down Expand Up @@ -82,3 +79,13 @@ def test_rename_bad_dimensions():
args = DummyArgs()
with pytest.raises(ValueError, match='Cannot rename to different dimensions'):
args.rename(q0, q1)


def test_transpose_qubits():
q0, q1, q2 = cirq.LineQubit.range(3)
args = DummyArgs()
assert args.transpose_to_qubit_order((q1, q0)).qubits == (q1, q0)
with pytest.raises(ValueError, match='Qubits do not match'):
args.transpose_to_qubit_order((q0, q2))
with pytest.raises(ValueError, match='Qubits do not match'):
args.transpose_to_qubit_order((q0, q1, q1))
92 changes: 35 additions & 57 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,78 +131,56 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
)
return bits

def copy(self) -> 'cirq.ActOnDensityMatrixArgs':
return ActOnDensityMatrixArgs(
target_tensor=self.target_tensor.copy(),
available_buffer=[b.copy() for b in self.available_buffer],
qubits=self.qubits,
qid_shape=self.qid_shape,
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results.copy(),
)
def _on_copy(self, target: 'ActOnDensityMatrixArgs'):
target.target_tensor = self.target_tensor.copy()
target.available_buffer = [b.copy() for b in self.available_buffer]

def kronecker_product(
self, other: 'cirq.ActOnDensityMatrixArgs'
) -> 'cirq.ActOnDensityMatrixArgs':
def _on_kronecker_product(
self, other: 'ActOnDensityMatrixArgs', target: 'ActOnDensityMatrixArgs'
):
target_tensor = transformations.density_matrix_kronecker_product(
self.target_tensor, other.target_tensor
)
buffer = [np.empty_like(target_tensor) for _ in self.available_buffer]
return ActOnDensityMatrixArgs(
target_tensor=target_tensor,
available_buffer=buffer,
qubits=self.qubits + other.qubits,
qid_shape=target_tensor.shape[: int(target_tensor.ndim / 2)],
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
target.target_tensor = target_tensor
target.available_buffer = [
np.empty_like(target_tensor) for _ in range(len(self.available_buffer))
]
target.qid_shape = target_tensor.shape[: int(target_tensor.ndim / 2)]

def factor(
def _on_factor(
self,
qubits: Sequence['cirq.Qid'],
*,
extracted: 'ActOnDensityMatrixArgs',
remainder: 'ActOnDensityMatrixArgs',
validate=True,
atol=1e-07,
) -> Tuple['cirq.ActOnDensityMatrixArgs', 'cirq.ActOnDensityMatrixArgs']:
):
axes = self.get_axes(qubits)
extracted_tensor, remainder_tensor = transformations.factor_density_matrix(
self.target_tensor, axes, validate=validate, atol=atol
)
buffer = [np.empty_like(extracted_tensor) for _ in self.available_buffer]
extracted_args = ActOnDensityMatrixArgs(
target_tensor=extracted_tensor,
available_buffer=buffer,
qubits=qubits,
qid_shape=extracted_tensor.shape[: int(extracted_tensor.ndim / 2)],
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
buffer = [np.empty_like(remainder_tensor) for _ in self.available_buffer]
remainder_args = ActOnDensityMatrixArgs(
target_tensor=remainder_tensor,
available_buffer=buffer,
qubits=tuple(q for q in self.qubits if q not in qubits),
qid_shape=remainder_tensor.shape[: int(remainder_tensor.ndim / 2)],
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
return extracted_args, remainder_args
extracted.target_tensor = extracted_tensor
extracted.available_buffer = [
np.empty_like(extracted_tensor) for _ in self.available_buffer
]
extracted.qid_shape = extracted_tensor.shape[: int(extracted_tensor.ndim / 2)]
remainder.target_tensor = remainder_tensor
remainder.available_buffer = [
np.empty_like(remainder_tensor) for _ in self.available_buffer
]
remainder.qid_shape = remainder_tensor.shape[: int(remainder_tensor.ndim / 2)]

def transpose_to_qubit_order(
self, qubits: Sequence['cirq.Qid']
) -> 'cirq.ActOnDensityMatrixArgs':
def _on_transpose_to_qubit_order(
self, qubits: Sequence['cirq.Qid'], target: 'ActOnDensityMatrixArgs'
):
axes = self.get_axes(qubits)
axes = axes + [i + len(qubits) for i in axes]
new_tensor = np.moveaxis(self.target_tensor, axes, range(len(qubits) * 2))
buffer = [np.empty_like(new_tensor) for _ in self.available_buffer]
return ActOnDensityMatrixArgs(
target_tensor=new_tensor,
available_buffer=buffer,
qubits=qubits,
qid_shape=new_tensor.shape[: int(new_tensor.ndim / 2)],
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
new_tensor = transformations.transpose_density_matrix_to_axis_order(
self.target_tensor, axes
)
buffer = [np.empty_like(new_tensor) for _ in self.available_buffer]
target.target_tensor = new_tensor
target.available_buffer = buffer
target.qid_shape = new_tensor.shape[: int(new_tensor.ndim / 2)]

def sample(
self,
Expand Down Expand Up @@ -239,7 +217,7 @@ def _strat_apply_channel_to_state(
)
if result is None:
return NotImplemented
for i in range(3):
for i in range(len(args.available_buffer)):
if result is args.available_buffer[i]:
args.available_buffer[i] = args.target_tensor
args.target_tensor = result
Expand Down

0 comments on commit b0f35a2

Please sign in to comment.