Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 62 additions & 10 deletions cirq-core/cirq/protocols/apply_unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,20 @@ class ApplyUnitaryArgs:
dtype as the target tensor.
axes: Which axes the unitary effect is being applied to (e.g. the
qubits that the gate is operating on).
subspaces: Which subspace (in the computational basis) the unitary
effect is being applied to, on each axis. By default it applies
to subspace 0..d-1 on each axis, where d is the dimension of the
unitary effect on that axis. Subspaces on each axis must be
representable as a slice, so the dimensions specified here need to
have a consistent step size.
"""

def __init__(
self, target_tensor: np.ndarray, available_buffer: np.ndarray, axes: Iterable[int]
self,
target_tensor: np.ndarray,
available_buffer: np.ndarray,
axes: Iterable[int],
subspaces: Optional[Sequence[Tuple[int, ...]]] = None,
):
"""Inits ApplyUnitaryArgs.

Expand All @@ -75,11 +85,27 @@ def __init__(
dtype as the target tensor.
axes: Which axes the unitary effect is being applied to (e.g. the
qubits that the gate is operating on).

subspaces: Which subspace (in the computational basis) the unitary
effect is being applied to, on each axis. By default it applies
to subspace 0..d-1 on each axis, where d is the dimension of
the unitary effect on that axis. Subspaces on each axis must be
representable as a slice, so the dimensions specified here need
to have a consistent step size.
Raises:
ValueError: If the subspace count does not equal the axis count, if
any subspace has zero dimensions, or if any subspace has
dimensions specified without a consistent step size.
"""
self.target_tensor = target_tensor
self.available_buffer = available_buffer
self.axes = tuple(axes)
if subspaces is not None:
if len(self.axes) != len(subspaces):
raise ValueError('Subspace count does not match axis count.')
for subspace, axis in zip(subspaces, self.axes):
if any(s >= target_tensor.shape[axis] for s in subspace):
raise ValueError('Subspace specified does not exist in axis.')
self.slices = None if subspaces is None else tuple(map(_to_slice, subspaces))

@staticmethod
def default(
Expand Down Expand Up @@ -125,7 +151,7 @@ def with_axes_transposed_to_start(self) -> 'ApplyUnitaryArgs':
return ApplyUnitaryArgs(target_tensor, available_buffer, range(len(self.axes)))

def _for_operation_with_qid_shape(
self, indices: Iterable[int], qid_shape: Tuple[int, ...]
self, indices: Iterable[int], slices: Tuple[Union[int, slice], ...]
) -> 'ApplyUnitaryArgs':
"""Creates a sliced and transposed view of `self` appropriate for an
operation with shape `qid_shape` on qubits with the given indices.
Expand All @@ -138,14 +164,14 @@ def _for_operation_with_qid_shape(
Args:
indices: Integer indices into `self.axes` specifying which qubits
the operation applies to.
qid_shape: The qid shape of the operation, the expected number of
quantum levels in each qubit the operation applies to.
slices: The slices of the operation, the subdimension in each qubit
the operation applies to.

Returns: A new `ApplyUnitaryArgs` where `sub_args.target_tensor` and
`sub_args.available_buffer` are sliced and transposed views of
`self.target_tensor` and `self.available_buffer` respectively.
"""
slices = [slice(0, size) for size in qid_shape]
slices = tuple(size if isinstance(size, slice) else slice(0, size) for size in slices)
sub_axes = [self.axes[i] for i in indices]
axis_set = set(sub_axes)
other_axes = [axis for axis in range(len(self.target_tensor.shape)) if axis not in axis_set]
Expand Down Expand Up @@ -369,8 +395,12 @@ def _strat_apply_unitary_from_apply_unitary(
func = getattr(unitary_value, '_apply_unitary_', None)
if func is None:
return NotImplemented
op_qid_shape = qid_shape_protocol.qid_shape(unitary_value, (2,) * len(args.axes))
sub_args = args._for_operation_with_qid_shape(range(len(op_qid_shape)), op_qid_shape)
if args.slices is None:
op_qid_shape = qid_shape_protocol.qid_shape(unitary_value, (2,) * len(args.axes))
slices = tuple(slice(0, size) for size in op_qid_shape)
else:
slices = args.slices
sub_args = args._for_operation_with_qid_shape(range(len(slices)), slices)
sub_result = func(sub_args)
if sub_result is NotImplemented or sub_result is None:
return sub_result
Expand All @@ -390,8 +420,15 @@ def _strat_apply_unitary_from_unitary(
if matrix is NotImplemented or matrix is None:
return matrix

val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes))
sub_args = args._for_operation_with_qid_shape(range(len(val_qid_shape)), val_qid_shape)
if args.slices is None:
val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes))
slices = tuple(slice(0, size) for size in val_qid_shape)
else:
slices = args.slices
val_qid_shape = tuple(
((s.step if s.stop is None else s.stop) - s.start) // (s.step or 1) for s in slices
)
sub_args = args._for_operation_with_qid_shape(range(len(slices)), slices)
matrix = matrix.astype(sub_args.target_tensor.dtype)
if len(val_qid_shape) == 1 and val_qid_shape[0] <= 2:
# Special case for single-qubit, 2x2 or 1x1 operations.
Expand Down Expand Up @@ -557,3 +594,18 @@ def _incorporate_result_into_target(
return args.available_buffer
sub_args.target_tensor[...] = sub_result
return args.target_tensor


def _to_slice(subspace_def: Tuple[int, ...]):
if len(subspace_def) < 1:
raise ValueError(f'Subspace {subspace_def} has zero dimensions.')

if len(subspace_def) == 1:
return slice(subspace_def[0], subspace_def[0] + 1, 1)

step = subspace_def[1] - subspace_def[0]
for i in range(len(subspace_def) - 1):
if subspace_def[i + 1] - subspace_def[i] != step:
raise ValueError(f'Subspace {subspace_def} does not have consistent step size.')
stop = subspace_def[-1] + step
return slice(subspace_def[0], stop if stop >= 0 else None, step)
252 changes: 252 additions & 0 deletions cirq-core/cirq/protocols/apply_unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,258 @@ def _unitary_(self):
)


# fmt: off
def test_subspace_size_2():
result = cirq.apply_unitary(
unitary_value=cirq.X,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 1)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 1, 0],
[1, 0, 0],
[0, 0, 1],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=cirq.X,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 2)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=cirq.X,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(1, 2)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1, 0, 0],
[0, 0, 1],
[0, 1, 0],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=cirq.X,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((4,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((4,), dtype=np.complex64),
axes=(0,),
subspaces=[(1, 2)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
]
),
atol=1e-8,
)


def test_subspaces_size_3():
plus_one_mod_3_gate = cirq.XPowGate(dimension=3)

result = cirq.apply_unitary(
unitary_value=plus_one_mod_3_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 1, 2)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=plus_one_mod_3_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(2, 1, 0)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 1, 0],
[0, 0, 1],
[1, 0, 0],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=plus_one_mod_3_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((4,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((4,), dtype=np.complex64),
axes=(0,),
subspaces=[(1, 2, 3)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1, 0, 0, 0],
[0, 0, 0, 1],
[0, 1, 0, 0],
[0, 0, 1, 0],
]
),
atol=1e-8,
)


def test_subspaces_size_1():
phase_gate = cirq.MatrixGate(np.array([[1j]]))

result = cirq.apply_unitary(
unitary_value=phase_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[(0,)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1j, 0],
[0, 1],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=phase_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[(1,)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1, 0],
[0, 1j],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=phase_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=np.array([[0, 1], [1, 0]], dtype=np.complex64),
available_buffer=np.zeros((2, 2), dtype=np.complex64),
axes=(0,),
subspaces=[(1,)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 1],
[1j, 0],
]
),
atol=1e-8,
)
# fmt: on


def test_invalid_subspaces():
with pytest.raises(ValueError, match='Subspace specified does not exist in axis'):
_ = cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[(1, 2)],
)
with pytest.raises(ValueError, match='Subspace count does not match axis count'):
_ = cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 1), (0, 1)],
)
with pytest.raises(ValueError, match='has zero dimensions'):
_ = cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[()],
)
with pytest.raises(ValueError, match='does not have consistent step size'):
_ = cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 2, 1)],
)


def test_incorporate_result_not_view():
tensor = np.zeros((2, 2))
tensor2 = np.zeros((2, 2))
Expand Down