diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol.py b/cirq-core/cirq/protocols/apply_unitary_protocol.py index 58c8e47d527..5c9cc396cdc 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol.py @@ -58,10 +58,20 @@ class ApplyUnitaryArgs: dtype as the target tensor. axes: Which axes the unitary effect is being applied to (e.g. the qubits that the gate is operating on). + subspaces: Which subspace (in the computational basis) the unitary + effect is being applied to, on each axis. By default it applies + to subspace 0..d-1 on each axis, where d is the dimension of the + unitary effect on that axis. Subspaces on each axis must be + representable as a slice, so the dimensions specified here need to + have a consistent step size. """ def __init__( - self, target_tensor: np.ndarray, available_buffer: np.ndarray, axes: Iterable[int] + self, + target_tensor: np.ndarray, + available_buffer: np.ndarray, + axes: Iterable[int], + subspaces: Optional[Sequence[Tuple[int, ...]]] = None, ): """Inits ApplyUnitaryArgs. @@ -75,11 +85,27 @@ def __init__( dtype as the target tensor. axes: Which axes the unitary effect is being applied to (e.g. the qubits that the gate is operating on). - + subspaces: Which subspace (in the computational basis) the unitary + effect is being applied to, on each axis. By default it applies + to subspace 0..d-1 on each axis, where d is the dimension of + the unitary effect on that axis. Subspaces on each axis must be + representable as a slice, so the dimensions specified here need + to have a consistent step size. + Raises: + ValueError: If the subspace count does not equal the axis count, if + any subspace has zero dimensions, or if any subspace has + dimensions specified without a consistent step size. """ self.target_tensor = target_tensor self.available_buffer = available_buffer self.axes = tuple(axes) + if subspaces is not None: + if len(self.axes) != len(subspaces): + raise ValueError('Subspace count does not match axis count.') + for subspace, axis in zip(subspaces, self.axes): + if any(s >= target_tensor.shape[axis] for s in subspace): + raise ValueError('Subspace specified does not exist in axis.') + self.slices = None if subspaces is None else tuple(map(_to_slice, subspaces)) @staticmethod def default( @@ -125,7 +151,7 @@ def with_axes_transposed_to_start(self) -> 'ApplyUnitaryArgs': return ApplyUnitaryArgs(target_tensor, available_buffer, range(len(self.axes))) def _for_operation_with_qid_shape( - self, indices: Iterable[int], qid_shape: Tuple[int, ...] + self, indices: Iterable[int], slices: Tuple[Union[int, slice], ...] ) -> 'ApplyUnitaryArgs': """Creates a sliced and transposed view of `self` appropriate for an operation with shape `qid_shape` on qubits with the given indices. @@ -138,14 +164,14 @@ def _for_operation_with_qid_shape( Args: indices: Integer indices into `self.axes` specifying which qubits the operation applies to. - qid_shape: The qid shape of the operation, the expected number of - quantum levels in each qubit the operation applies to. + slices: The slices of the operation, the subdimension in each qubit + the operation applies to. Returns: A new `ApplyUnitaryArgs` where `sub_args.target_tensor` and `sub_args.available_buffer` are sliced and transposed views of `self.target_tensor` and `self.available_buffer` respectively. """ - slices = [slice(0, size) for size in qid_shape] + slices = tuple(size if isinstance(size, slice) else slice(0, size) for size in slices) sub_axes = [self.axes[i] for i in indices] axis_set = set(sub_axes) other_axes = [axis for axis in range(len(self.target_tensor.shape)) if axis not in axis_set] @@ -369,8 +395,12 @@ def _strat_apply_unitary_from_apply_unitary( func = getattr(unitary_value, '_apply_unitary_', None) if func is None: return NotImplemented - op_qid_shape = qid_shape_protocol.qid_shape(unitary_value, (2,) * len(args.axes)) - sub_args = args._for_operation_with_qid_shape(range(len(op_qid_shape)), op_qid_shape) + if args.slices is None: + op_qid_shape = qid_shape_protocol.qid_shape(unitary_value, (2,) * len(args.axes)) + slices = tuple(slice(0, size) for size in op_qid_shape) + else: + slices = args.slices + sub_args = args._for_operation_with_qid_shape(range(len(slices)), slices) sub_result = func(sub_args) if sub_result is NotImplemented or sub_result is None: return sub_result @@ -390,8 +420,15 @@ def _strat_apply_unitary_from_unitary( if matrix is NotImplemented or matrix is None: return matrix - val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes)) - sub_args = args._for_operation_with_qid_shape(range(len(val_qid_shape)), val_qid_shape) + if args.slices is None: + val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes)) + slices = tuple(slice(0, size) for size in val_qid_shape) + else: + slices = args.slices + val_qid_shape = tuple( + ((s.step if s.stop is None else s.stop) - s.start) // (s.step or 1) for s in slices + ) + sub_args = args._for_operation_with_qid_shape(range(len(slices)), slices) matrix = matrix.astype(sub_args.target_tensor.dtype) if len(val_qid_shape) == 1 and val_qid_shape[0] <= 2: # Special case for single-qubit, 2x2 or 1x1 operations. @@ -557,3 +594,18 @@ def _incorporate_result_into_target( return args.available_buffer sub_args.target_tensor[...] = sub_result return args.target_tensor + + +def _to_slice(subspace_def: Tuple[int, ...]): + if len(subspace_def) < 1: + raise ValueError(f'Subspace {subspace_def} has zero dimensions.') + + if len(subspace_def) == 1: + return slice(subspace_def[0], subspace_def[0] + 1, 1) + + step = subspace_def[1] - subspace_def[0] + for i in range(len(subspace_def) - 1): + if subspace_def[i + 1] - subspace_def[i] != step: + raise ValueError(f'Subspace {subspace_def} does not have consistent step size.') + stop = subspace_def[-1] + step + return slice(subspace_def[0], stop if stop >= 0 else None, step) diff --git a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py index ef675180569..0f41ec77d20 100644 --- a/cirq-core/cirq/protocols/apply_unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/apply_unitary_protocol_test.py @@ -422,6 +422,258 @@ def _unitary_(self): ) +# fmt: off +def test_subspace_size_2(): + result = cirq.apply_unitary( + unitary_value=cirq.X, + args=cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((3,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((3,), dtype=np.complex64), + axes=(0,), + subspaces=[(0, 1)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [0, 1, 0], + [1, 0, 0], + [0, 0, 1], + ] + ), + atol=1e-8, + ) + + result = cirq.apply_unitary( + unitary_value=cirq.X, + args=cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((3,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((3,), dtype=np.complex64), + axes=(0,), + subspaces=[(0, 2)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + ] + ), + atol=1e-8, + ) + + result = cirq.apply_unitary( + unitary_value=cirq.X, + args=cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((3,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((3,), dtype=np.complex64), + axes=(0,), + subspaces=[(1, 2)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [1, 0, 0], + [0, 0, 1], + [0, 1, 0], + ] + ), + atol=1e-8, + ) + + result = cirq.apply_unitary( + unitary_value=cirq.X, + args=cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((4,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((4,), dtype=np.complex64), + axes=(0,), + subspaces=[(1, 2)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ] + ), + atol=1e-8, + ) + + +def test_subspaces_size_3(): + plus_one_mod_3_gate = cirq.XPowGate(dimension=3) + + result = cirq.apply_unitary( + unitary_value=plus_one_mod_3_gate, + args=cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((3,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((3,), dtype=np.complex64), + axes=(0,), + subspaces=[(0, 1, 2)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [0, 0, 1], + [1, 0, 0], + [0, 1, 0], + ] + ), + atol=1e-8, + ) + + result = cirq.apply_unitary( + unitary_value=plus_one_mod_3_gate, + args=cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((3,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((3,), dtype=np.complex64), + axes=(0,), + subspaces=[(2, 1, 0)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [0, 1, 0], + [0, 0, 1], + [1, 0, 0], + ] + ), + atol=1e-8, + ) + + result = cirq.apply_unitary( + unitary_value=plus_one_mod_3_gate, + args=cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((4,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((4,), dtype=np.complex64), + axes=(0,), + subspaces=[(1, 2, 3)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [1, 0, 0, 0], + [0, 0, 0, 1], + [0, 1, 0, 0], + [0, 0, 1, 0], + ] + ), + atol=1e-8, + ) + + +def test_subspaces_size_1(): + phase_gate = cirq.MatrixGate(np.array([[1j]])) + + result = cirq.apply_unitary( + unitary_value=phase_gate, + args=cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((2,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((2,), dtype=np.complex64), + axes=(0,), + subspaces=[(0,)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [1j, 0], + [0, 1], + ] + ), + atol=1e-8, + ) + + result = cirq.apply_unitary( + unitary_value=phase_gate, + args=cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((2,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((2,), dtype=np.complex64), + axes=(0,), + subspaces=[(1,)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [1, 0], + [0, 1j], + ] + ), + atol=1e-8, + ) + + result = cirq.apply_unitary( + unitary_value=phase_gate, + args=cirq.ApplyUnitaryArgs( + target_tensor=np.array([[0, 1], [1, 0]], dtype=np.complex64), + available_buffer=np.zeros((2, 2), dtype=np.complex64), + axes=(0,), + subspaces=[(1,)], + ), + ) + np.testing.assert_allclose( + result, + np.array( + [ + [0, 1], + [1j, 0], + ] + ), + atol=1e-8, + ) +# fmt: on + + +def test_invalid_subspaces(): + with pytest.raises(ValueError, match='Subspace specified does not exist in axis'): + _ = cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((2,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((2,), dtype=np.complex64), + axes=(0,), + subspaces=[(1, 2)], + ) + with pytest.raises(ValueError, match='Subspace count does not match axis count'): + _ = cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((2,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((2,), dtype=np.complex64), + axes=(0,), + subspaces=[(0, 1), (0, 1)], + ) + with pytest.raises(ValueError, match='has zero dimensions'): + _ = cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((2,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((2,), dtype=np.complex64), + axes=(0,), + subspaces=[()], + ) + with pytest.raises(ValueError, match='does not have consistent step size'): + _ = cirq.ApplyUnitaryArgs( + target_tensor=cirq.eye_tensor((3,), dtype=np.complex64), + available_buffer=cirq.eye_tensor((3,), dtype=np.complex64), + axes=(0,), + subspaces=[(0, 2, 1)], + ) + + def test_incorporate_result_not_view(): tensor = np.zeros((2, 2)) tensor2 = np.zeros((2, 2))