Skip to content

Commit 50eba00

Browse files
authored
Add TensorProductState, a state-vector-like object (#3171)
This is factored out of #2781 with @Strilanc 's initial comments addressed, per comments on that PR. `cirq.STATE_VECTOR_LIKE` accepts basis-state indices, sequences of per-qudit basis indices, or arrays of amplitudes. This PR extends this union type to accept `TensorProductState` -- an object representing a state vector which can be decomposed into a tensor product of single-qubit states. The structure shares a lot of similarity with `PauliString`. Whereas PauliString is a dictionary from qubits to (single-qubit) `Pauli` objects; here it is a mapping from qubits to single-qubit named states. Here we include `|0>, |1>, |+>, |->, |i>, |-i>`. As an example: ```python # start from the |+++> state q = cirq.LineQubit.range(3) initial_state = cirq.KET_PLUS(q[0]) * cirq.KET_PLUS(q[1]) * cirq.KET_PLUS(q[2]) ``` Since most things already use `STATE_VECTOR_LIKE` and its associated helper function `to_valid_state_vector`, only minor modifications to this function was necessary to enable specification of a tensor product state wherever initial states are sold.
1 parent b0f42a1 commit 50eba00

38 files changed

+820
-41
lines changed

cirq/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,14 @@
413413
TParamVal,
414414
validate_probability,
415415
value_equality,
416+
KET_PLUS,
417+
KET_MINUS,
418+
KET_IMAG,
419+
KET_MINUS_IMAG,
420+
KET_ZERO,
421+
KET_ONE,
422+
PAULI_STATES,
423+
ProductState,
416424
)
417425

418426
# pylint: disable=redefined-builtin

cirq/circuits/circuit_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2411,6 +2411,28 @@ def test_apply_unitary_effect_to_state():
24112411
np.array([0, 0, 1, 0]),
24122412
atol=1e-8)
24132413

2414+
# Product state
2415+
cirq.testing.assert_allclose_up_to_global_phase(cirq.Circuit(cirq.CNOT(
2416+
a, b)).final_state_vector(initial_state=cirq.KET_ZERO(a) *
2417+
cirq.KET_ZERO(b)),
2418+
np.array([1, 0, 0, 0]),
2419+
atol=1e-8)
2420+
cirq.testing.assert_allclose_up_to_global_phase(cirq.Circuit(cirq.CNOT(
2421+
a, b)).final_state_vector(initial_state=cirq.KET_ZERO(a) *
2422+
cirq.KET_ONE(b)),
2423+
np.array([0, 1, 0, 0]),
2424+
atol=1e-8)
2425+
cirq.testing.assert_allclose_up_to_global_phase(cirq.Circuit(cirq.CNOT(
2426+
a, b)).final_state_vector(initial_state=cirq.KET_ONE(a) *
2427+
cirq.KET_ZERO(b)),
2428+
np.array([0, 0, 0, 1]),
2429+
atol=1e-8)
2430+
cirq.testing.assert_allclose_up_to_global_phase(cirq.Circuit(cirq.CNOT(
2431+
a,
2432+
b)).final_state_vector(initial_state=cirq.KET_ONE(a) * cirq.KET_ONE(b)),
2433+
np.array([0, 0, 1, 0]),
2434+
atol=1e-8)
2435+
24142436
# Measurements.
24152437
cirq.testing.assert_allclose_up_to_global_phase(cirq.Circuit(
24162438
cirq.measure(a)).final_state_vector(),

cirq/ops/pauli_gates.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import abc
15-
from typing import Any, cast, Tuple, TYPE_CHECKING, Union
15+
from typing import Any, cast, Tuple, TYPE_CHECKING, Union, Dict
1616

17-
from cirq import value
1817
from cirq._doc import document
1918
from cirq.ops import common_gates, raw_types, identity
2019
from cirq.type_workarounds import NotImplementedType
@@ -23,6 +22,8 @@
2322
if TYPE_CHECKING:
2423
import cirq
2524
from cirq.ops.pauli_string import SingleQubitPauliStringGateOperation
25+
from cirq.value.product_state import (_XEigenState, _YEigenState,
26+
_ZEigenState) # coverage: ignore
2627

2728

2829
class Pauli(raw_types.Gate, metaclass=abc.ABCMeta):
@@ -98,19 +99,18 @@ def _canonical_exponent(self):
9899
"""Overrides EigenGate._canonical_exponent in subclasses."""
99100
return 1
100101

101-
102102
class _PauliX(Pauli, common_gates.XPowGate):
103103

104104
def __init__(self):
105105
Pauli.__init__(self, index=0, name='X')
106106
common_gates.XPowGate.__init__(self, exponent=1.0)
107107

108108
def __pow__(self: '_PauliX',
109-
exponent: value.TParamVal) -> common_gates.XPowGate:
109+
exponent: 'cirq.TParamVal') -> common_gates.XPowGate:
110110
return common_gates.XPowGate(exponent=exponent)
111111

112112
def _with_exponent(self: '_PauliX',
113-
exponent: value.TParamVal) -> common_gates.XPowGate:
113+
exponent: 'cirq.TParamVal') -> common_gates.XPowGate:
114114
return self.__pow__(exponent)
115115

116116
@classmethod
@@ -119,6 +119,14 @@ def _from_json_dict_(cls, exponent, global_shift, **kwargs):
119119
assert exponent == 1
120120
return Pauli._XYZ[0]
121121

122+
@property
123+
def basis(self: '_PauliX') -> Dict[int, '_XEigenState']:
124+
from cirq.value.product_state import _XEigenState
125+
return {
126+
+1: _XEigenState(+1),
127+
-1: _XEigenState(-1),
128+
}
129+
122130

123131
class _PauliY(Pauli, common_gates.YPowGate):
124132

@@ -127,11 +135,11 @@ def __init__(self):
127135
common_gates.YPowGate.__init__(self, exponent=1.0)
128136

129137
def __pow__(self: '_PauliY',
130-
exponent: value.TParamVal) -> common_gates.YPowGate:
138+
exponent: 'cirq.TParamVal') -> common_gates.YPowGate:
131139
return common_gates.YPowGate(exponent=exponent)
132140

133141
def _with_exponent(self: '_PauliY',
134-
exponent: value.TParamVal) -> common_gates.YPowGate:
142+
exponent: 'cirq.TParamVal') -> common_gates.YPowGate:
135143
return self.__pow__(exponent)
136144

137145
@classmethod
@@ -140,6 +148,14 @@ def _from_json_dict_(cls, exponent, global_shift, **kwargs):
140148
assert exponent == 1
141149
return Pauli._XYZ[1]
142150

151+
@property
152+
def basis(self: '_PauliY') -> Dict[int, '_YEigenState']:
153+
from cirq.value.product_state import _YEigenState
154+
return {
155+
+1: _YEigenState(+1),
156+
-1: _YEigenState(-1),
157+
}
158+
143159

144160
class _PauliZ(Pauli, common_gates.ZPowGate):
145161

@@ -148,11 +164,11 @@ def __init__(self):
148164
common_gates.ZPowGate.__init__(self, exponent=1.0)
149165

150166
def __pow__(self: '_PauliZ',
151-
exponent: value.TParamVal) -> common_gates.ZPowGate:
167+
exponent: 'cirq.TParamVal') -> common_gates.ZPowGate:
152168
return common_gates.ZPowGate(exponent=exponent)
153169

154170
def _with_exponent(self: '_PauliZ',
155-
exponent: value.TParamVal) -> common_gates.ZPowGate:
171+
exponent: 'cirq.TParamVal') -> common_gates.ZPowGate:
156172
return self.__pow__(exponent)
157173

158174
@classmethod
@@ -161,6 +177,14 @@ def _from_json_dict_(cls, exponent, global_shift, **kwargs):
161177
assert exponent == 1
162178
return Pauli._XYZ[2]
163179

180+
@property
181+
def basis(self: '_PauliZ') -> Dict[int, '_ZEigenState']:
182+
from cirq.value.product_state import _ZEigenState
183+
return {
184+
+1: _ZEigenState(+1),
185+
-1: _ZEigenState(-1),
186+
}
187+
164188

165189
X = _PauliX()
166190
document(

cirq/ops/pauli_string_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,31 @@ def test_to_z_basis_ops():
592592
z_basis_state, expected_state, rtol=1e-7, atol=1e-7)
593593

594594

595+
def test_to_z_basis_ops_product_state():
596+
q0, q1, q2, q3, q4, q5 = _make_qubits(6)
597+
pauli_string = cirq.PauliString({
598+
q0: cirq.X,
599+
q1: cirq.X,
600+
q2: cirq.Y,
601+
q3: cirq.Y,
602+
q4: cirq.Z,
603+
q5: cirq.Z
604+
})
605+
circuit = cirq.Circuit(pauli_string.to_z_basis_ops())
606+
607+
initial_state = cirq.KET_PLUS(q0) * cirq.KET_MINUS(q1) * cirq.KET_IMAG(
608+
q2) * cirq.KET_MINUS_IMAG(q3) * cirq.KET_ZERO(q4) * cirq.KET_ONE(q5)
609+
z_basis_state = circuit.final_state_vector(initial_state)
610+
611+
expected_state = np.zeros(2**6)
612+
expected_state[0b010101] = 1
613+
614+
cirq.testing.assert_allclose_up_to_global_phase(z_basis_state,
615+
expected_state,
616+
rtol=1e-7,
617+
atol=1e-7)
618+
619+
595620
def _assert_pass_over(ops: List[cirq.Operation],
596621
before: cirq.PauliString,
597622
after: cirq.PauliString):

cirq/protocols/json_serialization.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def two_qubit_matrix_gate(matrix):
8080
'AsymmetricDepolarizingChannel':
8181
cirq.AsymmetricDepolarizingChannel,
8282
'BitFlipChannel': cirq.BitFlipChannel,
83+
'ProductState': cirq.ProductState,
8384
'CCNotPowGate': cirq.CCNotPowGate,
8485
'CCXPowGate': cirq.CCXPowGate,
8586
'CCZPowGate': cirq.CCZPowGate,
@@ -120,6 +121,12 @@ def two_qubit_matrix_gate(matrix):
120121
'MatrixGate': cirq.MatrixGate,
121122
'MeasurementGate': cirq.MeasurementGate,
122123
'Moment': cirq.Moment,
124+
'_XEigenState':
125+
cirq.value.product_state._XEigenState, # type: ignore
126+
'_YEigenState':
127+
cirq.value.product_state._YEigenState, # type: ignore
128+
'_ZEigenState':
129+
cirq.value.product_state._ZEigenState, # type: ignore
123130
'_NamedConstantXmonDevice': _NamedConstantXmonDevice,
124131
'_NoNoiseModel': _NoNoiseModel,
125132
'NamedQubit': cirq.NamedQubit,

cirq/protocols/json_serialization_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def test_fail_to_resolve():
138138
# global objects
139139
'CONTROL_TAG',
140140
'PAULI_BASIS',
141+
'PAULI_STATES',
141142

142143
# abstract, but not inspect.isabstract():
143144
'Device',
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"cirq_type": "_YEigenState",
3+
"eigenvalue": 1
4+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cirq.KET_IMAG
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"cirq_type": "_XEigenState",
3+
"eigenvalue": -1
4+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cirq.KET_MINUS

0 commit comments

Comments
 (0)