From ad6e649419d10bd814b80220f94ea4f9ff665dde Mon Sep 17 00:00:00 2001 From: Victory Omole Date: Tue, 4 Jul 2023 18:31:28 -0500 Subject: [PATCH] Handle qubits in the __str__ of StateVectorTrialResult (#6180) --- cirq-core/cirq/sim/state_vector_simulator.py | 5 +++-- .../cirq/sim/state_vector_simulator_test.py | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index c6bcebe9a58..053ac2b0a48 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -20,6 +20,7 @@ from cirq import _compat, ops, value, qis from cirq.sim import simulator, state_vector, simulator_base +from cirq.protocols import qid_shape if TYPE_CHECKING: import cirq @@ -31,7 +32,7 @@ class SimulatesIntermediateStateVector( Generic[TStateVectorStepResult], simulator_base.SimulatorBase[ - TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.StateVectorSimulationState', + TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.StateVectorSimulationState' ], simulator.SimulatesAmplitudes, metaclass=abc.ABCMeta, @@ -172,7 +173,7 @@ def __str__(self) -> str: size = np.prod(shape, dtype=np.int64) final = final.reshape(size) if len([1 for e in final if abs(e) > 0.001]) < 16: - state_vector = qis.dirac_notation(final, 3) + state_vector = qis.dirac_notation(final, 3, qid_shape(substate.qubits)) else: state_vector = str(final) label = f'qubits: {substate.qubits}' if substate.qubits else 'phase:' diff --git a/cirq-core/cirq/sim/state_vector_simulator_test.py b/cirq-core/cirq/sim/state_vector_simulator_test.py index c1c35a91fb3..fdd9d2086cd 100644 --- a/cirq-core/cirq/sim/state_vector_simulator_test.py +++ b/cirq-core/cirq/sim/state_vector_simulator_test.py @@ -159,6 +159,28 @@ def test_str_big(): assert 'output vector: [0.03125+0.j 0.03125+0.j 0.03125+0.j ..' in str(result) +def test_str_qudit(): + qutrit = cirq.LineQid(0, dimension=3) + final_simulator_state = cirq.StateVectorSimulationState( + prng=np.random.RandomState(0), + qubits=[qutrit], + initial_state=np.array([0, 0, 1]), + dtype=np.complex64, + ) + result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state) + assert "|2⟩" in str(result) + + ququart = cirq.LineQid(0, dimension=4) + final_simulator_state = cirq.StateVectorSimulationState( + prng=np.random.RandomState(0), + qubits=[ququart], + initial_state=np.array([0, 1, 0, 0]), + dtype=np.complex64, + ) + result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state) + assert "|1⟩" in str(result) + + def test_pretty_print(): final_simulator_state = cirq.StateVectorSimulationState( available_buffer=np.array([1]),