diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 48c72551b63..12bfb2b43d9 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -217,6 +217,8 @@ InterchangeableQubitsGate, ISWAP, ISwapPowGate, + KetBra, + KetBraSum, LinearCombinationOfGates, LinearCombinationOfOperations, MatrixGate, diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 33a6848ed65..f89714e67ee 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -89,6 +89,8 @@ def two_qubit_matrix_gate(matrix): 'IdentityGate': cirq.IdentityGate, 'IdentityOperation': _identity_operation_from_dict, 'InitObsSetting': cirq.work.InitObsSetting, + 'KetBra': cirq.KetBra, + 'KetBraSum': cirq.KetBraSum, 'LinearDict': cirq.LinearDict, 'LineQubit': cirq.LineQubit, 'LineQid': cirq.LineQid, diff --git a/cirq-core/cirq/ops/__init__.py b/cirq-core/cirq/ops/__init__.py index 02ea527eb9c..6758f9f2f8f 100644 --- a/cirq-core/cirq/ops/__init__.py +++ b/cirq-core/cirq/ops/__init__.py @@ -112,6 +112,11 @@ IdentityGate, ) +from cirq.ops.ketbrasum import ( + KetBra, + KetBraSum, +) + from cirq.ops.global_phase_op import ( GlobalPhaseOperation, ) diff --git a/cirq-core/cirq/ops/ketbrasum.py b/cirq-core/cirq/ops/ketbrasum.py new file mode 100644 index 00000000000..43712d375e0 --- /dev/null +++ b/cirq-core/cirq/ops/ketbrasum.py @@ -0,0 +1,225 @@ +from typing import ( + Any, + Dict, + Iterable, + Mapping, + List, + Optional, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) + +import numpy as np + +from cirq import linalg, value +from cirq.ops import raw_types +from cirq.qis import states +from cirq.qis import STATE_VECTOR_LIKE + +if TYPE_CHECKING: + import cirq + from cirq import protocols + +KetBraKey = TypeVar('KetBraKey', bound=Union[raw_types.Qid, Tuple[raw_types.Qid]]) + + +def qid_shape_from_ket_bra_key(ket_bra_key: KetBraKey): + if isinstance(ket_bra_key, tuple): + return [qid.dimension for qid in ket_bra_key] + else: + return [ket_bra_key.dimension] + + +def get_dims_from_qid_map(qid_map: Mapping[raw_types.Qid, int]): + dims = sorted([(i, qid.dimension) for qid, i in qid_map.items()]) + return [x[1] for x in dims] + + +def get_qid_indices(qid_map: Mapping[raw_types.Qid, int], ket_bra_key: KetBraKey): + if isinstance(ket_bra_key, raw_types.Qid): + qid = ket_bra_key + if qid not in qid_map: + raise ValueError(f"Missing qid: {qid}") + return [qid_map[qid]] + else: + idx = [] + for qid in ket_bra_key: + if qid not in qid_map: + raise ValueError(f"Missing qid: {qid}") + idx.append(qid_map[qid]) + return idx + + +@value.value_equality +class KetBra: + def __init__( + self, + ket: Optional[STATE_VECTOR_LIKE] = None, + bra: Optional[STATE_VECTOR_LIKE] = None, + ket_bra_list: Optional[List[Tuple[STATE_VECTOR_LIKE, STATE_VECTOR_LIKE]]] = None, + ): + + if ket_bra_list is not None: + self.ket_bra_list = ket_bra_list + else: + self.ket_bra_list = [] + if ket is not None and bra is not None: + self.ket_bra_list.append( + ( + ket, + bra, + ) + ) + + def KetBra(self, ket: STATE_VECTOR_LIKE, bra: STATE_VECTOR_LIKE): + self.ket_bra_list.append( + ( + ket, + bra, + ) + ) + return self + + def __iter__(self): + return iter(self.ket_bra_list) + + def _json_dict_(self) -> Dict[str, Any]: + return { + 'cirq_type': self.__class__.__name__, + 'ket_bra_list': [list(ket_bra) for ket_bra in self.ket_bra_list], + } + + @classmethod + def _from_json_dict_(cls, ket_bra_list, **kwargs): + ket_bra_list = [tuple(ket_bra) for ket_bra in ket_bra_list] + return cls(ket_bra_list=ket_bra_list) + + def _value_equality_values_(self) -> Any: + return tuple(self.ket_bra_list) + + +@value.value_equality +class KetBraSum: + """A generic operation specified as a list of |ket> Dict[KetBraKey, KetBra]: + return self._ket_bra_dict + + def _op_matrix(self, ket_bra_key: KetBraKey) -> np.ndarray: + # TODO(tonybruguier): Speed up computation when the ket and bra are + # encoded as integers. This probably means not calling this function at + # all, as encoding a matrix with a single non-zero entry is not + # efficient. + qid_shape = qid_shape_from_ket_bra_key(ket_bra_key) + + P = 0 + for ket_bra in self._ket_bra_dict[ket_bra_key]: + ket = states.to_valid_state_vector(ket_bra[0], qid_shape=qid_shape) + bra = states.to_valid_state_vector(ket_bra[1], qid_shape=qid_shape).conj() + P = P + np.einsum('i,j->ij', ket, bra) + + return P + + def matrix(self, ket_bra_keys: Optional[Iterable[KetBraKey]] = None) -> Iterable[np.ndarray]: + ket_bra_keys = self._ket_bra_dict.keys() if ket_bra_keys is None else ket_bra_keys + factors = [] + for ket_bra_key in ket_bra_keys: + if ket_bra_key not in self._ket_bra_dict.keys(): + qid_shape = qid_shape_from_ket_bra_key(ket_bra_key) + factors.append(np.eye(np.prod(qid_shape))) + else: + factors.append(self._op_matrix(ket_bra_key)) + return linalg.kron(*factors) + + def expectation_from_state_vector( + self, + state_vector: np.ndarray, + qid_map: Mapping[raw_types.Qid, int], + *, + atol: float = 1e-7, + check_preconditions: bool = True, + ) -> float: + dims = get_dims_from_qid_map(qid_map) + state_vector = state_vector.reshape(dims) + + for ket_bra_key in self._ket_bra_dict.keys(): + idx = get_qid_indices(qid_map, ket_bra_key) + op_dims = qid_shape_from_ket_bra_key(ket_bra_key) + nr = len(idx) + + P = self._op_matrix(ket_bra_key) + P = np.reshape(P, op_dims * 2) + + state_vector = np.tensordot(P, state_vector, axes=(range(nr, 2 * nr), idx)) + state_vector = np.moveaxis(state_vector, range(nr), idx) + + state_vector = np.reshape(state_vector, np.prod(dims)) + return np.dot(state_vector, state_vector.conj()) + + def expectation_from_density_matrix( + self, + state: np.ndarray, + qid_map: Mapping[raw_types.Qid, int], + *, + atol: float = 1e-7, + check_preconditions: bool = True, + ) -> float: + dims = get_dims_from_qid_map(qid_map) + state = state.reshape(dims * 2) + + for ket_bra_key in self._ket_bra_dict.keys(): + idx = get_qid_indices(qid_map, ket_bra_key) + op_dims = qid_shape_from_ket_bra_key(ket_bra_key) + nr = len(idx) + + P = self._op_matrix(ket_bra_key) + P = np.reshape(P, op_dims * 2) + + state = np.tensordot(P, state, axes=(range(nr, 2 * nr), idx)) + state = np.moveaxis(state, range(nr), idx) + state = np.tensordot(state, P.T.conj(), axes=([len(dims) + i for i in idx], range(nr))) + state = np.moveaxis(state, range(-nr, 0), [len(dims) + i for i in idx]) + + state = np.reshape(state, [np.prod(dims)] * 2) + return np.trace(state) + + def __repr__(self) -> str: + return f"cirq.KetBraSum(ket_bra_dict={self._ket_bra_dict})" + + def _json_dict_(self) -> Dict[str, Any]: + encoded_dict = {k: [(t[0], t[1]) for t in v] for k, v in self._ket_bra_dict.items()} + return { + 'cirq_type': self.__class__.__name__, + # JSON requires mappings to have string keys. + 'ket_bra_dict': list(encoded_dict.items()), + } + + @classmethod + def _from_json_dict_(cls, ket_bra_dict, **kwargs): + encoded_dict = dict(ket_bra_dict) + return cls(ket_bra_dict={k: [tuple(t) for t in v] for k, v in encoded_dict.items()}) + + def _value_equality_values_(self) -> Any: + ket_bra_dict = sorted(self._ket_bra_dict.items()) + encoded_dict = {k: tuple([(t[0], t[1]) for t in v]) for k, v in ket_bra_dict} + return tuple(encoded_dict.items()) diff --git a/cirq-core/cirq/ops/ketbrasum_test.py b/cirq-core/cirq/ops/ketbrasum_test.py new file mode 100644 index 00000000000..e91946b6a97 --- /dev/null +++ b/cirq-core/cirq/ops/ketbrasum_test.py @@ -0,0 +1,296 @@ +from itertools import permutations +import math + +import numpy as np +import pytest + +import cirq + + +def test_ket_bra_building(): + ketbra = cirq.KetBra(1, 2).KetBra(3, 4) + assert ketbra.ket_bra_list == [ + ( + 1, + 2, + ), + ( + 3, + 4, + ), + ] + + +def test_ket_bra_qid(): + q0 = cirq.NamedQubit('q0') + q1 = cirq.NamedQubit('q1') + + zero_projector = cirq.KetBraSum({q0: cirq.KetBra(0, 0)}) + one_projector = cirq.KetBraSum({q0: cirq.KetBra(1, 1)}) + not_a_projector = cirq.KetBraSum({q0: cirq.KetBra(0, 1)}) + two_qids = cirq.KetBraSum({(q0, q1): cirq.KetBra(1, 3)}) + + np.testing.assert_allclose(zero_projector.matrix(), [[1.0, 0.0], [0.0, 0.0]]) + np.testing.assert_allclose(one_projector.matrix(), [[0.0, 0.0], [0.0, 1.0]]) + np.testing.assert_allclose(not_a_projector.matrix(), [[0.0, 1.0], [0.0, 0.0]]) + np.testing.assert_allclose( + two_qids.matrix(), + [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + ) + + +def test_ket_bra_from_np_array(): + q0 = cirq.NamedQubit('q0') + + zero_projector = cirq.KetBraSum({q0: cirq.KetBra(0, 0)}) + np.testing.assert_allclose(zero_projector.matrix(), [[1.0, 0.0], [0.0, 0.0]]) + + +def test_ket_bra_plus(): + q0 = cirq.NamedQubit('q0') + + plus = [1.0 / math.sqrt(2), 1.0 / math.sqrt(2)] + plus_projector = cirq.KetBraSum( + { + q0: cirq.KetBra( + plus, + plus, + ) + } + ) + + np.testing.assert_allclose(plus_projector.matrix(), [[0.5, 0.5], [0.5, 0.5]]) + + +def test_ket_bra_matrix_missing_qid(): + q0, q1 = cirq.LineQubit.range(2) + proj = cirq.KetBraSum({q0: cirq.KetBra(0, 0)}) + + np.testing.assert_allclose(proj.matrix(), [[1.0, 0.0], [0.0, 0.0]]) + np.testing.assert_allclose(proj.matrix([q0]), [[1.0, 0.0], [0.0, 0.0]]) + np.testing.assert_allclose(proj.matrix([q1]), [[1.0, 0.0], [0.0, 1.0]]) + + np.testing.assert_allclose(proj.matrix([q0, q1]), np.diag([1.0, 1.0, 0.0, 0.0])) + np.testing.assert_allclose(proj.matrix([q1, q0]), np.diag([1.0, 0.0, 1.0, 0.0])) + + +def test_ket_bra_from_state_missing_qid(): + q0 = cirq.NamedQubit('q0') + q1 = cirq.NamedQubit('q1') + + d = cirq.KetBraSum({q0: cirq.KetBra(0, 0)}) + + with pytest.raises(ValueError, match="Missing qid: q0"): + d.expectation_from_state_vector(np.array([[0.0, 0.0]]), qid_map={q1: 0}) + + with pytest.raises(ValueError, match="Missing qid: q0"): + d.expectation_from_density_matrix(np.array([[0.0, 0.0], [0.0, 0.0]]), qid_map={q1: 0}) + + +def test_ket_bra_from_state_missing_proj_key(): + q0 = cirq.NamedQubit('q0') + q1 = cirq.NamedQubit('q1') + + d = cirq.KetBraSum({(q0, q1): cirq.KetBra(0, 0)}) + + with pytest.raises(ValueError, match="Missing qid: q0"): + d.expectation_from_state_vector(np.array([[0.0, 0.0]]), qid_map={q1: 0}) + + with pytest.raises(ValueError, match="Missing qid: q0"): + d.expectation_from_density_matrix(np.array([[0.0, 0.0], [0.0, 0.0]]), qid_map={q1: 0}) + + +def test_equality(): + q0 = cirq.NamedQubit('q0') + + obj1 = cirq.KetBraSum({q0: cirq.KetBra(0, 0)}) + obj2 = cirq.KetBraSum({q0: cirq.KetBra(1, 1)}) + + assert obj1 == obj1 + assert obj1 != obj2 + assert hash(obj1) == hash(obj1) + assert hash(obj1) != hash(obj2) + + +def test_ket_bra_qutrit(): + (q0,) = cirq.LineQid.range(1, dimension=3) + + zero_projector = cirq.KetBraSum({q0: cirq.KetBra(0, 0)}) + one_projector = cirq.KetBraSum({q0: cirq.KetBra(1, 1)}) + two_projector = cirq.KetBraSum({q0: cirq.KetBra(2, 2)}) + + np.testing.assert_allclose( + zero_projector.matrix(), [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + ) + + np.testing.assert_allclose( + one_projector.matrix(), [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]] + ) + + np.testing.assert_allclose( + two_projector.matrix(), [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]] + ) + + +def test_get_values(): + q0 = cirq.NamedQubit('q0') + d = cirq.KetBraSum({q0: cirq.KetBra(0, 0)}) + + assert len(d._ket_bra_dict_()) == 1 + assert np.allclose( + list(d._ket_bra_dict_()[q0]), + [ + ( + 0, + 0, + ) + ], + ) + + +def test_repr(): + q0 = cirq.NamedQubit('q0') + d = cirq.KetBraSum({q0: [[1.0, 0.0]]}) + + assert d.__repr__() == ("cirq.KetBraSum(ket_bra_dict={cirq.NamedQubit('q0'): [[1.0, 0.0]]})") + + +def test_consistency_with_existing(): + a, b = cirq.LineQubit.range(2) + mx = (cirq.KET_IMAG(a) * cirq.KET_IMAG(b)).projector() + proj_vec = np.asarray([0.5, 0.5j, 0.5j, -0.5]) + ii_proj = cirq.KetBraSum({(a, b): cirq.KetBra(proj_vec, proj_vec)}) + np.testing.assert_allclose(mx, ii_proj.matrix()) + + +def test_expectation_from_state_vector_basis_states_empty(): + q0 = cirq.NamedQubit('q0') + d = cirq.KetBraSum({}) + + np.testing.assert_allclose(d.expectation_from_state_vector(np.array([1.0, 0.0]), {q0: 0}), 1.0) + + +def test_expectation_from_state_vector_basis_states_single_qubits(): + q0 = cirq.NamedQubit('q0') + d = cirq.KetBraSum({q0: cirq.KetBra(0, 0)}) + + np.testing.assert_allclose(d.expectation_from_state_vector(np.array([1.0, 0.0]), {q0: 0}), 1.0) + np.testing.assert_allclose(d.expectation_from_state_vector(np.array([0.0, 1.0]), {q0: 0}), 0.0) + + +def test_expectation_from_state_vector_basis_states_three_qubits(): + q0 = cirq.NamedQubit('q0') + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + d = cirq.KetBraSum({q0: cirq.KetBra(0, 0), q1: cirq.KetBra(1, 1)}) + + np.testing.assert_allclose( + d.expectation_from_state_vector( + np.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), {q0: 0, q1: 1, q2: 2} + ), + 0.0, + ) + + np.testing.assert_allclose( + d.expectation_from_state_vector( + np.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), {q0: 0, q1: 2, q2: 1} + ), + 1.0, + ) + + +def test_expectation_higher_dims(): + q0 = cirq.NamedQid('q0', dimension=2) + q1 = cirq.NamedQid('q1', dimension=3) + q2 = cirq.NamedQid('q2', dimension=5) + d = cirq.KetBraSum({q2: cirq.KetBra(3, 3), q1: cirq.KetBra(1, 1)}) + + phis = [[1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0]] + + for perm in permutations([0, 1, 2]): + inv_perm = [-1] * len(perm) + for i, j in enumerate(perm): + inv_perm[j] = i + + state_vector = np.kron(phis[perm[0]], np.kron(phis[perm[1]], phis[perm[2]])) + state = np.einsum('i,j->ij', state_vector, state_vector.T.conj()) + + np.testing.assert_allclose( + d.expectation_from_state_vector( + state_vector, {q0: inv_perm[0], q1: inv_perm[1], q2: inv_perm[2]} + ), + 1.0, + ) + + np.testing.assert_allclose( + d.expectation_from_density_matrix( + state, {q0: inv_perm[0], q1: inv_perm[1], q2: inv_perm[2]} + ), + 1.0, + ) + + +def test_expectation_from_density_matrix_basis_states_empty(): + q0 = cirq.NamedQubit('q0') + d = cirq.KetBraSum({}) + + np.testing.assert_allclose( + d.expectation_from_density_matrix(np.array([[1.0, 0.0], [0.0, 0.0]]), {q0: 0}), 1.0 + ) + + +def test_expectation_from_density_matrix_basis_states_single_qubits(): + q0 = cirq.NamedQubit('q0') + d = cirq.KetBraSum({q0: cirq.KetBra(0, 0)}) + + np.testing.assert_allclose( + d.expectation_from_density_matrix(np.array([[1.0, 0.0], [0.0, 0.0]]), {q0: 0}), 1.0 + ) + np.testing.assert_allclose( + d.expectation_from_density_matrix(np.array([[0.0, 0.0], [0.0, 1.0]]), {q0: 0}), 0.0 + ) + + +def test_internal_consistency(): + q0 = cirq.NamedQid('q0', dimension=2) + q1 = cirq.NamedQid('q1', dimension=3) + + phi0 = np.asarray([1.0, -3.0j]) + phi1 = np.asarray([-0.5j, 1.0 + 2.0j, 1.2]) + + state_vector = np.asarray([1.0, 2.0j, -3.0, -4.0j, -5.0j, 0.0]) + + phi0 = phi0 / np.linalg.norm(phi0) + phi1 = phi1 / np.linalg.norm(phi1) + state_vector = state_vector / np.linalg.norm(state_vector) + state = np.einsum('i,j->ij', state_vector, state_vector.T.conj()) + + d = cirq.KetBraSum({q0: cirq.KetBra(phi0, phi0), q1: cirq.KetBra(phi1, phi1)}) + P = d.matrix(ket_bra_keys=[q1, q0]) + + projected_state = np.matmul(P, state_vector) + actual0 = np.linalg.norm(projected_state, ord=2) ** 2 + + actual1 = d.expectation_from_state_vector(state_vector, qid_map={q1: 0, q0: 1}) + + actual2 = d.expectation_from_density_matrix(state, qid_map={q1: 0, q0: 1}) + + np.testing.assert_allclose(actual0, actual1, atol=1e-6) + np.testing.assert_allclose(actual0, actual2, atol=1e-6) + + +def test_ket_bra_split_qubits(): + q0, q1, q2 = cirq.LineQubit.range(3) + phi = np.asarray([1.0 / math.sqrt(2), 0.0, 0.0, 1.0 / math.sqrt(2)]) + d = cirq.KetBraSum({(q0, q2): cirq.KetBra(phi, phi)}) + + qid_map = {q0: 0, q1: 1, q2: 2} + + state_vector = np.asarray([0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.5]) + state = np.einsum('i,j->ij', state_vector, state_vector.T.conj()) + + actual1 = d.expectation_from_state_vector(state_vector, qid_map=qid_map) + actual2 = d.expectation_from_density_matrix(state, qid_map=qid_map) + + np.testing.assert_allclose(actual1, 0.25, atol=1e-6) + np.testing.assert_allclose(actual2, 0.25, atol=1e-6) diff --git a/cirq-core/cirq/protocols/json_test_data/KetBra.json b/cirq-core/cirq/protocols/json_test_data/KetBra.json new file mode 100644 index 00000000000..45e7ca1c276 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KetBra.json @@ -0,0 +1,4 @@ +{ + "cirq_type": "KetBra", + "ket_bra_list": [[1, 2]] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/KetBra.repr b/cirq-core/cirq/protocols/json_test_data/KetBra.repr new file mode 100644 index 00000000000..e1338048552 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KetBra.repr @@ -0,0 +1 @@ +cirq.KetBra(1, 2) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/KetBraSum.json b/cirq-core/cirq/protocols/json_test_data/KetBraSum.json new file mode 100644 index 00000000000..234a79e1d46 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KetBraSum.json @@ -0,0 +1,23 @@ +{ + "cirq_type": "KetBraSum", + "ket_bra_dict": [ + [ + { + "cirq_type": "NamedQubit", + "name": "q0" + }, + [ + [ + [ + 1.0, + 0.0 + ], + [ + 0.0, + 1.0 + ] + ] + ] + ] + ] +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/KetBraSum.repr b/cirq-core/cirq/protocols/json_test_data/KetBraSum.repr new file mode 100644 index 00000000000..923b394df40 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KetBraSum.repr @@ -0,0 +1 @@ +cirq.KetBraSum(ket_bra_dict={cirq.NamedQubit('q0'): [([1.0, 0.0], [0.0, 1.0])]}) \ No newline at end of file