Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ActOnArgs.kron/factor/transpose to reduce code duplication #4463

Merged
merged 10 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 5 additions & 11 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,17 +298,11 @@ def __str__(self) -> str:
def _value_equality_values_(self) -> Any:
return self.qubit_map, self.M, self.simulation_options, self.grouping

def copy(self) -> 'MPSState':
state = MPSState(
qubits=self.qubits,
prng=self.prng,
simulation_options=self.simulation_options,
grouping=self.grouping,
log_of_measurement_results=self.log_of_measurement_results.copy(),
)
state.M = [x.copy() for x in self.M]
state.estimated_gate_error_list = self.estimated_gate_error_list
return state
def _on_copy(self, target: 'MPSState'):
target.simulation_options = self.simulation_options
target.grouping = self.grouping
target.M = [x.copy() for x in self.M]
target.estimated_gate_error_list = self.estimated_gate_error_list

def state_vector(self) -> np.ndarray:
"""Returns the full state vector.
Expand Down
75 changes: 66 additions & 9 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,15 @@ def __init__(
axes = ()
if log_of_measurement_results is None:
log_of_measurement_results = {}
self._qubits = tuple(qubits)
self.qubit_map = {q: i for i, q in enumerate(qubits)}
self._set_qubits(qubits)
self._axes = tuple(axes)
self.prng = prng
self._log_of_measurement_results = log_of_measurement_results

def _set_qubits(self, qubits: Sequence['cirq.Qid']):
self._qubits = tuple(qubits)
self.qubit_map = {q: i for i, q in enumerate(self.qubits)}

# TODO(#3388) Add documentation for Raises.
# pylint: disable=missing-raises-doc
def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]):
Expand Down Expand Up @@ -105,31 +108,85 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Child classes that perform measurements should implement this with
the implementation."""

@abc.abstractmethod
def copy(self: TSelf) -> TSelf:
"""Creates a copy of the object."""
args = copy.copy(self)
self._on_copy(args)
args._log_of_measurement_results = self.log_of_measurement_results.copy()
return args

def _on_copy(self: TSelf, args: TSelf):
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
"""Subclasses should implement this with any additional state copy
functionality."""

def create_merged_state(self: TSelf) -> TSelf:
"""Creates a final merged state."""
return self

def kronecker_product(self: TSelf, other: TSelf) -> TSelf:
def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf:
"""Joins two state spaces together."""
raise NotImplementedError()
args = self if inplace else copy.copy(self)
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
self._on_kronecker_product(other, args)
args._set_qubits(self.qubits + other.qubits)
return args

def _on_kronecker_product(self: TSelf, other: TSelf, target: TSelf):
"""Subclasses should implement this with any additional state product
functionality, if supported."""

def factor(
self: TSelf,
qubits: Sequence['cirq.Qid'],
*,
validate=True,
atol=1e-07,
inplace=False,
) -> Tuple[TSelf, TSelf]:
"""Splits two state spaces after a measurement or reset."""
raise NotImplementedError()
extracted = copy.copy(self)
remainder = self if inplace else copy.copy(self)
self._on_factor(qubits, extracted, remainder, validate, atol)
extracted._set_qubits(qubits)
remainder._set_qubits([q for q in self.qubits if q not in qubits])
return extracted, remainder

def _on_factor(
self: TSelf,
qubits: Sequence['cirq.Qid'],
extracted: TSelf,
remainder: TSelf,
validate=True,
atol=1e-07,
):
"""Subclasses should implement this with any additional state factor
functionality, if supported."""

def transpose_to_qubit_order(
self: TSelf, qubits: Sequence['cirq.Qid'], *, inplace=False
) -> TSelf:
"""Physically reindexes the state by the new basis.

Args:
qubits: The desired qubit order.
inplace: True to perform this operation inplace.

Returns:
The state with qubit order transposed and underlying representation
updated.

Raises:
ValueError: If the provided qubits do not match the existing ones.
"""
if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits):
raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}')
args = self if inplace else copy.copy(self)
self._on_transpose_to_qubit_order(qubits, args)
args._set_qubits(qubits)
return args

def transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid']) -> TSelf:
"""Physically reindexes the state by the new basis."""
raise NotImplementedError()
def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], target: TSelf):
"""Subclasses should implement this with any additional state transpose
functionality, if supported."""

@property
def log_of_measurement_results(self) -> Dict[str, Any]:
Expand Down
35 changes: 9 additions & 26 deletions cirq-core/cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,17 @@ def _act_on_fallback_(
) -> bool:
return True

def kronecker_product(self, other: 'EmptyActOnArgs') -> 'EmptyActOnArgs':
return EmptyActOnArgs(
qubits=self.qubits + other.qubits,
logs=self.log_of_measurement_results,
)
def _on_copy(self, args):
pass

def factor(
self,
qubits: Sequence['cirq.Qid'],
*,
validate=True,
atol=1e-07,
) -> Tuple['EmptyActOnArgs', 'EmptyActOnArgs']:
extracted_args = EmptyActOnArgs(
qubits=qubits,
logs=self.log_of_measurement_results,
)
remainder_args = EmptyActOnArgs(
qubits=tuple(q for q in self.qubits if q not in qubits),
logs=self.log_of_measurement_results,
)
return extracted_args, remainder_args
def _on_kronecker_product(self, other, target):
pass

def transpose_to_qubit_order(self, qubits: Sequence['cirq.Qid']) -> 'EmptyActOnArgs':
return EmptyActOnArgs(
qubits=qubits,
logs=self.log_of_measurement_results,
)
def _on_transpose_to_qubit_order(self, qubits, target):
pass

def _on_factor(self, qubits, extracted, remainder, validate=True, atol=1e-07):
pass

def sample(self, qubits, repetitions=1, seed=None):
pass
Expand Down
13 changes: 10 additions & 3 deletions cirq-core/cirq/sim/act_on_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ class DummyArgs(cirq.ActOnArgs):
def __init__(self):
super().__init__(qubits=cirq.LineQubit.range(2))

def copy(self):
pass

def sample(self, qubits, repetitions=1, seed=None):
pass

Expand Down Expand Up @@ -82,3 +79,13 @@ def test_rename_bad_dimensions():
args = DummyArgs()
with pytest.raises(ValueError, match='Cannot rename to different dimensions'):
args.rename(q0, q1)


def test_transpose_qubits():
q0, q1, q2 = cirq.LineQubit.range(3)
args = DummyArgs()
assert args.transpose_to_qubit_order((q1, q0)).qubits == (q1, q0)
with pytest.raises(ValueError, match='Qubits do not match'):
args.transpose_to_qubit_order((q0, q2))
with pytest.raises(ValueError, match='Qubits do not match'):
args.transpose_to_qubit_order((q0, q1, q1))
92 changes: 35 additions & 57 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,78 +131,56 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
)
return bits

def copy(self) -> 'cirq.ActOnDensityMatrixArgs':
return ActOnDensityMatrixArgs(
target_tensor=self.target_tensor.copy(),
available_buffer=[b.copy() for b in self.available_buffer],
qubits=self.qubits,
qid_shape=self.qid_shape,
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results.copy(),
)
def _on_copy(self, target: 'ActOnDensityMatrixArgs'):
target.target_tensor = self.target_tensor.copy()
target.available_buffer = [b.copy() for b in self.available_buffer]

def kronecker_product(
self, other: 'cirq.ActOnDensityMatrixArgs'
) -> 'cirq.ActOnDensityMatrixArgs':
def _on_kronecker_product(
self, other: 'ActOnDensityMatrixArgs', target: 'ActOnDensityMatrixArgs'
):
target_tensor = transformations.density_matrix_kronecker_product(
self.target_tensor, other.target_tensor
)
buffer = [np.empty_like(target_tensor) for _ in self.available_buffer]
return ActOnDensityMatrixArgs(
target_tensor=target_tensor,
available_buffer=buffer,
qubits=self.qubits + other.qubits,
qid_shape=target_tensor.shape[: int(target_tensor.ndim / 2)],
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
target.target_tensor = target_tensor
target.available_buffer = [
np.empty_like(target_tensor) for _ in range(len(self.available_buffer))
]
target.qid_shape = target_tensor.shape[: int(target_tensor.ndim / 2)]

def factor(
def _on_factor(
self,
qubits: Sequence['cirq.Qid'],
*,
extracted: 'ActOnDensityMatrixArgs',
remainder: 'ActOnDensityMatrixArgs',
validate=True,
atol=1e-07,
) -> Tuple['cirq.ActOnDensityMatrixArgs', 'cirq.ActOnDensityMatrixArgs']:
):
axes = self.get_axes(qubits)
extracted_tensor, remainder_tensor = transformations.factor_density_matrix(
self.target_tensor, axes, validate=validate, atol=atol
)
buffer = [np.empty_like(extracted_tensor) for _ in self.available_buffer]
extracted_args = ActOnDensityMatrixArgs(
target_tensor=extracted_tensor,
available_buffer=buffer,
qubits=qubits,
qid_shape=extracted_tensor.shape[: int(extracted_tensor.ndim / 2)],
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
buffer = [np.empty_like(remainder_tensor) for _ in self.available_buffer]
remainder_args = ActOnDensityMatrixArgs(
target_tensor=remainder_tensor,
available_buffer=buffer,
qubits=tuple(q for q in self.qubits if q not in qubits),
qid_shape=remainder_tensor.shape[: int(remainder_tensor.ndim / 2)],
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
)
return extracted_args, remainder_args
extracted.target_tensor = extracted_tensor
extracted.available_buffer = [
np.empty_like(extracted_tensor) for _ in self.available_buffer
]
extracted.qid_shape = extracted_tensor.shape[: int(extracted_tensor.ndim / 2)]
remainder.target_tensor = remainder_tensor
remainder.available_buffer = [
np.empty_like(remainder_tensor) for _ in self.available_buffer
]
remainder.qid_shape = remainder_tensor.shape[: int(remainder_tensor.ndim / 2)]

def transpose_to_qubit_order(
self, qubits: Sequence['cirq.Qid']
) -> 'cirq.ActOnDensityMatrixArgs':
def _on_transpose_to_qubit_order(
self, qubits: Sequence['cirq.Qid'], target: 'ActOnDensityMatrixArgs'
):
axes = self.get_axes(qubits)
axes = axes + [i + len(qubits) for i in axes]
new_tensor = np.moveaxis(self.target_tensor, axes, range(len(qubits) * 2))
buffer = [np.empty_like(new_tensor) for _ in self.available_buffer]
return ActOnDensityMatrixArgs(
target_tensor=new_tensor,
available_buffer=buffer,
qubits=qubits,
qid_shape=new_tensor.shape[: int(new_tensor.ndim / 2)],
prng=self.prng,
log_of_measurement_results=self.log_of_measurement_results,
new_tensor = transformations.transpose_density_matrix_to_axis_order(
self.target_tensor, axes
)
buffer = [np.empty_like(new_tensor) for _ in self.available_buffer]
target.target_tensor = new_tensor
target.available_buffer = buffer
target.qid_shape = new_tensor.shape[: int(new_tensor.ndim / 2)]

def sample(
self,
Expand Down Expand Up @@ -239,7 +217,7 @@ def _strat_apply_channel_to_state(
)
if result is None:
return NotImplemented
for i in range(3):
for i in range(len(args.available_buffer)):
if result is args.available_buffer[i]:
args.available_buffer[i] = args.target_tensor
args.target_tensor = result
Expand Down